progress.rs

  1use std::{
  2    borrow::Cow,
  3    collections::HashMap,
  4    io::{IsTerminal, Write},
  5    sync::{Arc, Mutex, OnceLock},
  6    time::{Duration, Instant},
  7};
  8
  9use log::{Level, Log, Metadata, Record};
 10
 11pub struct Progress {
 12    inner: Mutex<ProgressInner>,
 13}
 14
 15struct ProgressInner {
 16    completed: Vec<CompletedTask>,
 17    in_progress: HashMap<String, InProgressTask>,
 18    is_tty: bool,
 19    terminal_width: usize,
 20    max_example_name_len: usize,
 21    status_lines_displayed: usize,
 22    total_examples: usize,
 23    failed_examples: usize,
 24    last_line_is_logging: bool,
 25}
 26
 27#[derive(Clone)]
 28struct InProgressTask {
 29    step: Step,
 30    started_at: Instant,
 31    substatus: Option<String>,
 32    info: Option<(String, InfoStyle)>,
 33}
 34
 35struct CompletedTask {
 36    step: Step,
 37    example_name: String,
 38    duration: Duration,
 39    info: Option<(String, InfoStyle)>,
 40}
 41
 42#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 43pub enum Step {
 44    LoadProject,
 45    Context,
 46    FormatPrompt,
 47    Predict,
 48    Score,
 49    Synthesize,
 50}
 51
 52#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 53pub enum InfoStyle {
 54    Normal,
 55    Warning,
 56}
 57
 58impl Step {
 59    pub fn label(&self) -> &'static str {
 60        match self {
 61            Step::LoadProject => "Load",
 62            Step::Context => "Context",
 63            Step::FormatPrompt => "Format",
 64            Step::Predict => "Predict",
 65            Step::Score => "Score",
 66            Step::Synthesize => "Synthesize",
 67        }
 68    }
 69
 70    fn color_code(&self) -> &'static str {
 71        match self {
 72            Step::LoadProject => "\x1b[33m",
 73            Step::Context => "\x1b[35m",
 74            Step::FormatPrompt => "\x1b[34m",
 75            Step::Predict => "\x1b[32m",
 76            Step::Score => "\x1b[31m",
 77            Step::Synthesize => "\x1b[36m",
 78        }
 79    }
 80}
 81
 82static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
 83static LOGGER: ProgressLogger = ProgressLogger;
 84
 85const MARGIN: usize = 4;
 86const MAX_STATUS_LINES: usize = 10;
 87
 88impl Progress {
 89    /// Returns the global Progress instance, initializing it if necessary.
 90    pub fn global() -> Arc<Progress> {
 91        GLOBAL
 92            .get_or_init(|| {
 93                let progress = Arc::new(Self {
 94                    inner: Mutex::new(ProgressInner {
 95                        completed: Vec::new(),
 96                        in_progress: HashMap::new(),
 97                        is_tty: std::io::stderr().is_terminal(),
 98                        terminal_width: get_terminal_width(),
 99                        max_example_name_len: 0,
100                        status_lines_displayed: 0,
101                        total_examples: 0,
102                        failed_examples: 0,
103                        last_line_is_logging: false,
104                    }),
105                });
106                let _ = log::set_logger(&LOGGER);
107                log::set_max_level(log::LevelFilter::Error);
108                progress
109            })
110            .clone()
111    }
112
113    pub fn set_total_examples(&self, total: usize) {
114        let mut inner = self.inner.lock().unwrap();
115        inner.total_examples = total;
116    }
117
118    pub fn increment_failed(&self) {
119        let mut inner = self.inner.lock().unwrap();
120        inner.failed_examples += 1;
121    }
122
123    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
124    /// This should be used for any output that needs to appear above the status lines.
125    fn log(&self, message: &str) {
126        let mut inner = self.inner.lock().unwrap();
127        Self::clear_status_lines(&mut inner);
128
129        if !inner.last_line_is_logging {
130            let reset = "\x1b[0m";
131            let dim = "\x1b[2m";
132            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
133            eprintln!("{dim}{divider}{reset}");
134            inner.last_line_is_logging = true;
135        }
136
137        eprintln!("{}", message);
138    }
139
140    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
141        let mut inner = self.inner.lock().unwrap();
142
143        Self::clear_status_lines(&mut inner);
144
145        inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
146        inner.in_progress.insert(
147            example_name.to_string(),
148            InProgressTask {
149                step,
150                started_at: Instant::now(),
151                substatus: None,
152                info: None,
153            },
154        );
155
156        Self::print_status_lines(&mut inner);
157
158        StepProgress {
159            progress: self.clone(),
160            step,
161            example_name: example_name.to_string(),
162        }
163    }
164
165    fn finish(&self, step: Step, example_name: &str) {
166        let mut inner = self.inner.lock().unwrap();
167
168        let Some(task) = inner.in_progress.remove(example_name) else {
169            return;
170        };
171
172        if task.step == step {
173            inner.completed.push(CompletedTask {
174                step: task.step,
175                example_name: example_name.to_string(),
176                duration: task.started_at.elapsed(),
177                info: task.info,
178            });
179
180            Self::clear_status_lines(&mut inner);
181            Self::print_logging_closing_divider(&mut inner);
182            Self::print_completed(&inner, inner.completed.last().unwrap());
183            Self::print_status_lines(&mut inner);
184        } else {
185            inner.in_progress.insert(example_name.to_string(), task);
186        }
187    }
188
189    fn print_logging_closing_divider(inner: &mut ProgressInner) {
190        if inner.last_line_is_logging {
191            let reset = "\x1b[0m";
192            let dim = "\x1b[2m";
193            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
194            eprintln!("{dim}{divider}{reset}");
195            inner.last_line_is_logging = false;
196        }
197    }
198
199    fn clear_status_lines(inner: &mut ProgressInner) {
200        if inner.is_tty && inner.status_lines_displayed > 0 {
201            // Move up and clear each line we previously displayed
202            for _ in 0..inner.status_lines_displayed {
203                eprint!("\x1b[A\x1b[K");
204            }
205            let _ = std::io::stderr().flush();
206            inner.status_lines_displayed = 0;
207        }
208    }
209
210    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
211        let duration = format_duration(task.duration);
212        let name_width = inner.max_example_name_len;
213
214        if inner.is_tty {
215            let reset = "\x1b[0m";
216            let bold = "\x1b[1m";
217            let dim = "\x1b[2m";
218
219            let yellow = "\x1b[33m";
220            let info_part = task
221                .info
222                .as_ref()
223                .map(|(s, style)| {
224                    if *style == InfoStyle::Warning {
225                        format!("{yellow}{s}{reset}")
226                    } else {
227                        s.to_string()
228                    }
229                })
230                .unwrap_or_default();
231
232            let prefix = format!(
233                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
234                color = task.step.color_code(),
235                label = task.step.label(),
236                name = task.example_name,
237            );
238
239            let duration_with_margin = format!("{duration} ");
240            let padding_needed = inner
241                .terminal_width
242                .saturating_sub(MARGIN)
243                .saturating_sub(duration_with_margin.len())
244                .saturating_sub(strip_ansi_len(&prefix));
245            let padding = " ".repeat(padding_needed);
246
247            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
248        } else {
249            let info_part = task
250                .info
251                .as_ref()
252                .map(|(s, _)| format!(" | {}", s))
253                .unwrap_or_default();
254
255            eprintln!(
256                "{label:>12} {name:<name_width$}{info_part} {duration}",
257                label = task.step.label(),
258                name = task.example_name,
259            );
260        }
261    }
262
263    fn print_status_lines(inner: &mut ProgressInner) {
264        if !inner.is_tty || inner.in_progress.is_empty() {
265            inner.status_lines_displayed = 0;
266            return;
267        }
268
269        let reset = "\x1b[0m";
270        let bold = "\x1b[1m";
271        let dim = "\x1b[2m";
272
273        // Build the done/in-progress/total label
274        let done_count = inner.completed.len();
275        let in_progress_count = inner.in_progress.len();
276        let failed_count = inner.failed_examples;
277
278        let failed_label = if failed_count > 0 {
279            format!(" {} failed ", failed_count)
280        } else {
281            String::new()
282        };
283
284        let range_label = format!(
285            " {}/{}/{} ",
286            done_count, in_progress_count, inner.total_examples
287        );
288
289        // Print a divider line with failed count on left, range label on right
290        let failed_visible_len = strip_ansi_len(&failed_label);
291        let range_visible_len = range_label.len();
292        let middle_divider_len = inner
293            .terminal_width
294            .saturating_sub(MARGIN * 2)
295            .saturating_sub(failed_visible_len)
296            .saturating_sub(range_visible_len);
297        let left_divider = "".repeat(MARGIN);
298        let middle_divider = "".repeat(middle_divider_len);
299        let right_divider = "".repeat(MARGIN);
300        eprintln!(
301            "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
302        );
303
304        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
305        tasks.sort_by_key(|(name, _)| *name);
306
307        let total_tasks = tasks.len();
308        let mut lines_printed = 0;
309
310        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
311            let elapsed = format_duration(task.started_at.elapsed());
312            let substatus_part = task
313                .substatus
314                .as_ref()
315                .map(|s| truncate_with_ellipsis(s, 30))
316                .unwrap_or_default();
317
318            let step_label = task.step.label();
319            let step_color = task.step.color_code();
320            let name_width = inner.max_example_name_len;
321
322            let prefix = format!(
323                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
324                name = name,
325            );
326
327            let duration_with_margin = format!("{elapsed} ");
328            let padding_needed = inner
329                .terminal_width
330                .saturating_sub(MARGIN)
331                .saturating_sub(duration_with_margin.len())
332                .saturating_sub(strip_ansi_len(&prefix));
333            let padding = " ".repeat(padding_needed);
334
335            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
336            lines_printed += 1;
337        }
338
339        // Show "+N more" on its own line if there are more tasks
340        if total_tasks > MAX_STATUS_LINES {
341            let remaining = total_tasks - MAX_STATUS_LINES;
342            eprintln!("{:>12} +{remaining} more", "");
343            lines_printed += 1;
344        }
345
346        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
347        let _ = std::io::stderr().flush();
348    }
349
350    pub fn finalize(&self) {
351        let mut inner = self.inner.lock().unwrap();
352        Self::clear_status_lines(&mut inner);
353
354        // Print summary if there were failures
355        if inner.failed_examples > 0 {
356            let total_processed = inner.completed.len() + inner.failed_examples;
357            let percentage = if total_processed > 0 {
358                inner.failed_examples as f64 / total_processed as f64 * 100.0
359            } else {
360                0.0
361            };
362            eprintln!(
363                "\n{} of {} examples failed ({:.1}%)",
364                inner.failed_examples, total_processed, percentage
365            );
366        }
367    }
368}
369
370pub struct StepProgress {
371    progress: Arc<Progress>,
372    step: Step,
373    example_name: String,
374}
375
376impl StepProgress {
377    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
378        let mut inner = self.progress.inner.lock().unwrap();
379        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
380            task.substatus = Some(substatus.into().into_owned());
381            Progress::clear_status_lines(&mut inner);
382            Progress::print_status_lines(&mut inner);
383        }
384    }
385
386    pub fn clear_substatus(&self) {
387        let mut inner = self.progress.inner.lock().unwrap();
388        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
389            task.substatus = None;
390            Progress::clear_status_lines(&mut inner);
391            Progress::print_status_lines(&mut inner);
392        }
393    }
394
395    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
396        let mut inner = self.progress.inner.lock().unwrap();
397        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
398            task.info = Some((info.into(), style));
399        }
400    }
401}
402
403impl Drop for StepProgress {
404    fn drop(&mut self) {
405        self.progress.finish(self.step, &self.example_name);
406    }
407}
408
/// `log` backend that routes records through the global `Progress` (when it exists) so they
/// print above the live status lines.
struct ProgressLogger;
410
411impl Log for ProgressLogger {
412    fn enabled(&self, metadata: &Metadata) -> bool {
413        metadata.level() <= Level::Info
414    }
415
416    fn log(&self, record: &Record) {
417        if !self.enabled(record.metadata()) {
418            return;
419        }
420
421        let level_color = match record.level() {
422            Level::Error => "\x1b[31m",
423            Level::Warn => "\x1b[33m",
424            Level::Info => "\x1b[32m",
425            Level::Debug => "\x1b[34m",
426            Level::Trace => "\x1b[35m",
427        };
428        let reset = "\x1b[0m";
429        let bold = "\x1b[1m";
430
431        let level_label = match record.level() {
432            Level::Error => "Error",
433            Level::Warn => "Warn",
434            Level::Info => "Info",
435            Level::Debug => "Debug",
436            Level::Trace => "Trace",
437        };
438
439        let message = format!(
440            "{bold}{level_color}{level_label:>12}{reset} {}",
441            record.args()
442        );
443
444        if let Some(progress) = GLOBAL.get() {
445            progress.log(&message);
446        } else {
447            eprintln!("{}", message);
448        }
449    }
450
451    fn flush(&self) {
452        let _ = std::io::stderr().flush();
453    }
454}
455
456#[cfg(unix)]
457fn get_terminal_width() -> usize {
458    unsafe {
459        let mut winsize: libc::winsize = std::mem::zeroed();
460        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
461            && winsize.ws_col > 0
462        {
463            winsize.ws_col as usize
464        } else {
465            80
466        }
467    }
468}
469
470#[cfg(not(unix))]
471fn get_terminal_width() -> usize {
472    80
473}
474
/// Returns the number of visible `char`s in `s`, skipping ANSI escape sequences.
///
/// Escape handling:
/// - CSI sequences (`ESC [`, parameter/intermediate bytes, then a final byte in `@`..=`~`)
///   are skipped. This covers SGR color codes (`ESC[…m`) as well as cursor/erase codes such as
///   `ESC[A` and `ESC[K`.
/// - Any other `ESC x` pair is treated as a two-character escape.
///
/// The result counts `char`s, not terminal columns, so wide glyphs (CJK, emoji) count as one.
fn strip_ansi_len(s: &str) -> usize {
    let mut chars = s.chars();
    let mut len = 0;
    while let Some(c) = chars.next() {
        if c != '\x1b' {
            len += 1;
            continue;
        }
        // The char after ESC is consumed in every case (it is part of the escape).
        if chars.next() == Some('[') {
            // Skip parameters up to and including the CSI final byte.
            for c in chars.by_ref() {
                if ('@'..='~').contains(&c) {
                    break;
                }
            }
        }
    }
    len
}
491
/// Shortens `s` to at most `max_len` chars.
///
/// When anything is cut, the last kept position holds `…`. Lengths are counted in `char`s,
/// so multi-byte text is never split inside a character (byte slicing would panic there).
/// With `max_len == 0` the result is always empty.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    if s.chars().count() <= max_len {
        return s.to_string();
    }
    if max_len == 0 {
        return String::new();
    }
    let mut truncated: String = s.chars().take(max_len - 1).collect();
    truncated.push('…');
    truncated
}
499
/// Formats a duration compactly:
/// - under one second: whole milliseconds (`250ms`);
/// - under one minute: seconds with one decimal (`1.5s`);
/// - otherwise: minutes with one decimal (`2.0m`).
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    match duration.as_millis() as f32 {
        ms if ms < MILLIS_PER_SECOND => format!("{ms}ms"),
        ms if ms < MILLIS_PER_MINUTE => format!("{:.1}s", ms / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms / MILLIS_PER_MINUTE),
    }
}