// progress.rs

use std::{
    borrow::Cow,
    collections::HashMap,
    io::{IsTerminal, Write},
    sync::{Arc, Mutex, MutexGuard, OnceLock},
    time::{Duration, Instant},
};

use log::{Level, Log, Metadata, Record};
 10
 11pub struct Progress {
 12    inner: Mutex<ProgressInner>,
 13}
 14
 15struct ProgressInner {
 16    completed: Vec<CompletedTask>,
 17    in_progress: HashMap<String, InProgressTask>,
 18    is_tty: bool,
 19    terminal_width: usize,
 20    max_example_name_len: usize,
 21    status_lines_displayed: usize,
 22    total_examples: usize,
 23    failed_examples: usize,
 24    last_line_is_logging: bool,
 25    ticker: Option<std::thread::JoinHandle<()>>,
 26}
 27
 28#[derive(Clone)]
 29struct InProgressTask {
 30    step: Step,
 31    started_at: Instant,
 32    substatus: Option<String>,
 33    info: Option<(String, InfoStyle)>,
 34}
 35
 36struct CompletedTask {
 37    step: Step,
 38    example_name: String,
 39    duration: Duration,
 40    info: Option<(String, InfoStyle)>,
 41}
 42
 43#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 44pub enum Step {
 45    LoadProject,
 46    Context,
 47    FormatPrompt,
 48    Predict,
 49    Score,
 50    Synthesize,
 51    PullExamples,
 52}
 53
 54#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 55pub enum InfoStyle {
 56    Normal,
 57    Warning,
 58}
 59
 60impl Step {
 61    pub fn label(&self) -> &'static str {
 62        match self {
 63            Step::LoadProject => "Load",
 64            Step::Context => "Context",
 65            Step::FormatPrompt => "Format",
 66            Step::Predict => "Predict",
 67            Step::Score => "Score",
 68            Step::Synthesize => "Synthesize",
 69            Step::PullExamples => "Pull",
 70        }
 71    }
 72
 73    fn color_code(&self) -> &'static str {
 74        match self {
 75            Step::LoadProject => "\x1b[33m",
 76            Step::Context => "\x1b[35m",
 77            Step::FormatPrompt => "\x1b[34m",
 78            Step::Predict => "\x1b[32m",
 79            Step::Score => "\x1b[31m",
 80            Step::Synthesize => "\x1b[36m",
 81            Step::PullExamples => "\x1b[36m",
 82        }
 83    }
 84}
 85
 86static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
 87static LOGGER: ProgressLogger = ProgressLogger;
 88
 89const MARGIN: usize = 4;
 90const MAX_STATUS_LINES: usize = 10;
 91const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
 92
 93impl Progress {
 94    /// Returns the global Progress instance, initializing it if necessary.
 95    pub fn global() -> Arc<Progress> {
 96        GLOBAL
 97            .get_or_init(|| {
 98                let progress = Arc::new(Self {
 99                    inner: Mutex::new(ProgressInner {
100                        completed: Vec::new(),
101                        in_progress: HashMap::new(),
102                        is_tty: std::env::var("NO_COLOR").is_err()
103                            && std::io::stderr().is_terminal(),
104                        terminal_width: get_terminal_width(),
105                        max_example_name_len: 0,
106                        status_lines_displayed: 0,
107                        total_examples: 0,
108                        failed_examples: 0,
109                        last_line_is_logging: false,
110                        ticker: None,
111                    }),
112                });
113                let _ = log::set_logger(&LOGGER);
114                log::set_max_level(log::LevelFilter::Info);
115                progress
116            })
117            .clone()
118    }
119
120    pub fn set_total_examples(&self, total: usize) {
121        let mut inner = self.inner.lock().unwrap();
122        inner.total_examples = total;
123    }
124
125    pub fn increment_failed(&self) {
126        let mut inner = self.inner.lock().unwrap();
127        inner.failed_examples += 1;
128    }
129
130    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
131    /// This should be used for any output that needs to appear above the status lines.
132    fn log(&self, message: &str) {
133        let mut inner = self.inner.lock().unwrap();
134        Self::clear_status_lines(&mut inner);
135
136        if !inner.last_line_is_logging {
137            let reset = "\x1b[0m";
138            let dim = "\x1b[2m";
139            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
140            eprintln!("{dim}{divider}{reset}");
141            inner.last_line_is_logging = true;
142        }
143
144        eprintln!("{}", message);
145    }
146
147    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
148        let mut inner = self.inner.lock().unwrap();
149
150        Self::clear_status_lines(&mut inner);
151
152        let max_name_width = inner
153            .terminal_width
154            .saturating_sub(MARGIN * 2)
155            .saturating_div(3)
156            .max(1);
157        inner.max_example_name_len = inner
158            .max_example_name_len
159            .max(example_name.len().min(max_name_width));
160        inner.in_progress.insert(
161            example_name.to_string(),
162            InProgressTask {
163                step,
164                started_at: Instant::now(),
165                substatus: None,
166                info: None,
167            },
168        );
169
170        if inner.is_tty && inner.ticker.is_none() {
171            let progress = self.clone();
172            inner.ticker = Some(std::thread::spawn(move || {
173                loop {
174                    std::thread::sleep(STATUS_TICK_INTERVAL);
175
176                    let mut inner = progress.inner.lock().unwrap();
177                    if inner.in_progress.is_empty() {
178                        break;
179                    }
180
181                    Progress::clear_status_lines(&mut inner);
182                    Progress::print_status_lines(&mut inner);
183                }
184            }));
185        }
186
187        Self::print_status_lines(&mut inner);
188
189        StepProgress {
190            progress: self.clone(),
191            step,
192            example_name: example_name.to_string(),
193        }
194    }
195
196    fn finish(&self, step: Step, example_name: &str) {
197        let mut inner = self.inner.lock().unwrap();
198
199        let Some(task) = inner.in_progress.remove(example_name) else {
200            return;
201        };
202
203        if task.step == step {
204            inner.completed.push(CompletedTask {
205                step: task.step,
206                example_name: example_name.to_string(),
207                duration: task.started_at.elapsed(),
208                info: task.info,
209            });
210
211            Self::clear_status_lines(&mut inner);
212            Self::print_logging_closing_divider(&mut inner);
213            if let Some(last_completed) = inner.completed.last() {
214                Self::print_completed(&inner, last_completed);
215            }
216            Self::print_status_lines(&mut inner);
217        } else {
218            inner.in_progress.insert(example_name.to_string(), task);
219        }
220    }
221
222    fn print_logging_closing_divider(inner: &mut ProgressInner) {
223        if inner.last_line_is_logging {
224            let reset = "\x1b[0m";
225            let dim = "\x1b[2m";
226            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
227            eprintln!("{dim}{divider}{reset}");
228            inner.last_line_is_logging = false;
229        }
230    }
231
232    fn clear_status_lines(inner: &mut ProgressInner) {
233        if inner.is_tty && inner.status_lines_displayed > 0 {
234            // Move up and clear each line we previously displayed
235            for _ in 0..inner.status_lines_displayed {
236                eprint!("\x1b[A\x1b[K");
237            }
238            let _ = std::io::stderr().flush();
239            inner.status_lines_displayed = 0;
240        }
241    }
242
243    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
244        let duration = format_duration(task.duration);
245        let name_width = inner.max_example_name_len;
246        let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
247
248        if inner.is_tty {
249            let reset = "\x1b[0m";
250            let bold = "\x1b[1m";
251            let dim = "\x1b[2m";
252
253            let yellow = "\x1b[33m";
254            let info_part = task
255                .info
256                .as_ref()
257                .map(|(s, style)| {
258                    if *style == InfoStyle::Warning {
259                        format!("{yellow}{s}{reset}")
260                    } else {
261                        s.to_string()
262                    }
263                })
264                .unwrap_or_default();
265
266            let prefix = format!(
267                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
268                color = task.step.color_code(),
269                label = task.step.label(),
270                name = truncated_name,
271            );
272
273            let duration_with_margin = format!("{duration} ");
274            let padding_needed = inner
275                .terminal_width
276                .saturating_sub(MARGIN)
277                .saturating_sub(duration_with_margin.len())
278                .saturating_sub(strip_ansi_len(&prefix));
279            let padding = " ".repeat(padding_needed);
280
281            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
282        } else {
283            let info_part = task
284                .info
285                .as_ref()
286                .map(|(s, _)| format!(" | {}", s))
287                .unwrap_or_default();
288
289            eprintln!(
290                "{label:>12} {name:<name_width$}{info_part} {duration}",
291                label = task.step.label(),
292                name = truncate_with_ellipsis(&task.example_name, name_width),
293            );
294        }
295    }
296
297    fn print_status_lines(inner: &mut ProgressInner) {
298        if !inner.is_tty || inner.in_progress.is_empty() {
299            inner.status_lines_displayed = 0;
300            return;
301        }
302
303        let reset = "\x1b[0m";
304        let bold = "\x1b[1m";
305        let dim = "\x1b[2m";
306
307        // Build the done/in-progress/total label
308        let done_count = inner.completed.len();
309        let in_progress_count = inner.in_progress.len();
310        let failed_count = inner.failed_examples;
311
312        let failed_label = if failed_count > 0 {
313            format!(" {} failed ", failed_count)
314        } else {
315            String::new()
316        };
317
318        let range_label = format!(
319            " {}/{}/{} ",
320            done_count, in_progress_count, inner.total_examples
321        );
322
323        // Print a divider line with failed count on left, range label on right
324        let failed_visible_len = strip_ansi_len(&failed_label);
325        let range_visible_len = range_label.len();
326        let middle_divider_len = inner
327            .terminal_width
328            .saturating_sub(MARGIN * 2)
329            .saturating_sub(failed_visible_len)
330            .saturating_sub(range_visible_len);
331        let left_divider = "".repeat(MARGIN);
332        let middle_divider = "".repeat(middle_divider_len);
333        let right_divider = "".repeat(MARGIN);
334        eprintln!(
335            "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
336        );
337
338        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
339        tasks.sort_by_key(|(name, _)| *name);
340
341        let total_tasks = tasks.len();
342        let mut lines_printed = 0;
343
344        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
345            let elapsed = format_duration(task.started_at.elapsed());
346            let substatus_part = task
347                .substatus
348                .as_ref()
349                .map(|s| truncate_with_ellipsis(s, 30))
350                .unwrap_or_default();
351
352            let step_label = task.step.label();
353            let step_color = task.step.color_code();
354            let name_width = inner.max_example_name_len;
355            let truncated_name = truncate_with_ellipsis(name, name_width);
356
357            let prefix = format!(
358                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
359                name = truncated_name,
360            );
361
362            let duration_with_margin = format!("{elapsed} ");
363            let padding_needed = inner
364                .terminal_width
365                .saturating_sub(MARGIN)
366                .saturating_sub(duration_with_margin.len())
367                .saturating_sub(strip_ansi_len(&prefix));
368            let padding = " ".repeat(padding_needed);
369
370            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
371            lines_printed += 1;
372        }
373
374        // Show "+N more" on its own line if there are more tasks
375        if total_tasks > MAX_STATUS_LINES {
376            let remaining = total_tasks - MAX_STATUS_LINES;
377            eprintln!("{:>12} +{remaining} more", "");
378            lines_printed += 1;
379        }
380
381        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
382        let _ = std::io::stderr().flush();
383    }
384
385    pub fn finalize(&self) {
386        let ticker = {
387            let mut inner = self.inner.lock().unwrap();
388            inner.ticker.take()
389        };
390
391        if let Some(ticker) = ticker {
392            let _ = ticker.join();
393        }
394
395        let mut inner = self.inner.lock().unwrap();
396        Self::clear_status_lines(&mut inner);
397
398        // Print summary if there were failures
399        if inner.failed_examples > 0 {
400            let total_examples = inner.total_examples;
401            let percentage = if total_examples > 0 {
402                inner.failed_examples as f64 / total_examples as f64 * 100.0
403            } else {
404                0.0
405            };
406            eprintln!(
407                "\n{} of {} examples failed ({:.1}%)",
408                inner.failed_examples, total_examples, percentage
409            );
410        }
411    }
412}
413
414pub struct StepProgress {
415    progress: Arc<Progress>,
416    step: Step,
417    example_name: String,
418}
419
420impl StepProgress {
421    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
422        let mut inner = self.progress.inner.lock().unwrap();
423        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
424            task.substatus = Some(substatus.into().into_owned());
425            Progress::clear_status_lines(&mut inner);
426            Progress::print_status_lines(&mut inner);
427        }
428    }
429
430    pub fn clear_substatus(&self) {
431        let mut inner = self.progress.inner.lock().unwrap();
432        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
433            task.substatus = None;
434            Progress::clear_status_lines(&mut inner);
435            Progress::print_status_lines(&mut inner);
436        }
437    }
438
439    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
440        let mut inner = self.progress.inner.lock().unwrap();
441        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
442            task.info = Some((info.into(), style));
443        }
444    }
445}
446
447impl Drop for StepProgress {
448    fn drop(&mut self) {
449        self.progress.finish(self.step, &self.example_name);
450    }
451}
452
453struct ProgressLogger;
454
455impl Log for ProgressLogger {
456    fn enabled(&self, metadata: &Metadata) -> bool {
457        metadata.level() <= Level::Info
458    }
459
460    fn log(&self, record: &Record) {
461        if !self.enabled(record.metadata()) {
462            return;
463        }
464
465        let level_color = match record.level() {
466            Level::Error => "\x1b[31m",
467            Level::Warn => "\x1b[33m",
468            Level::Info => "\x1b[32m",
469            Level::Debug => "\x1b[34m",
470            Level::Trace => "\x1b[35m",
471        };
472        let reset = "\x1b[0m";
473        let bold = "\x1b[1m";
474
475        let level_label = match record.level() {
476            Level::Error => "Error",
477            Level::Warn => "Warn",
478            Level::Info => "Info",
479            Level::Debug => "Debug",
480            Level::Trace => "Trace",
481        };
482
483        let message = format!(
484            "{bold}{level_color}{level_label:>12}{reset} {}",
485            record.args()
486        );
487
488        if let Some(progress) = GLOBAL.get() {
489            progress.log(&message);
490        } else {
491            eprintln!("{}", message);
492        }
493    }
494
495    fn flush(&self) {
496        let _ = std::io::stderr().flush();
497    }
498}
499
500#[cfg(unix)]
501fn get_terminal_width() -> usize {
502    unsafe {
503        let mut winsize: libc::winsize = std::mem::zeroed();
504        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
505            && winsize.ws_col > 0
506        {
507            winsize.ws_col as usize
508        } else {
509            80
510        }
511    }
512}
513
/// Terminal width used on platforms where it is not queried: always 80 columns.
#[cfg(not(unix))]
fn get_terminal_width() -> usize {
    const DEFAULT_COLUMNS: usize = 80;
    DEFAULT_COLUMNS
}
518
/// Counts the characters of `s` that are visible on a terminal, skipping ANSI escapes.
///
/// An escape starts at `ESC` (`\x1b`) and is assumed to run through the next `m`. This
/// covers the SGR color/style codes this module emits; other escape kinds are not parsed.
fn strip_ansi_len(s: &str) -> usize {
    let (visible, _) = s
        .chars()
        .fold((0usize, false), |(visible, in_escape), c| match (in_escape, c) {
            (_, '\x1b') => (visible, true),
            (true, 'm') => (visible, false),
            (true, _) => (visible, true),
            (false, _) => (visible + 1, false),
        });
    visible
}
535
/// Shortens `s` to at most `max_len` characters, ending with `…` when it had to be cut.
///
/// Lengths are counted in `char`s rather than bytes. This means multi-byte text is never
/// split mid-character (which would panic), and the result lines up with `{:<width$}`
/// padding, which also counts `char`s. With `max_len == 0` a non-empty input yields "".
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    // `nth(max_len)` is `None` exactly when `s` has at most `max_len` chars.
    if s.chars().nth(max_len).is_none() {
        return s.to_string();
    }
    let Some(keep) = max_len.checked_sub(1) else {
        return String::new();
    };
    let mut truncated: String = s.chars().take(keep).collect();
    truncated.push('…');
    truncated
}
543
/// Renders a duration compactly.
///
/// Whole milliseconds are shown below one second, tenths of a second below one minute,
/// and tenths of a minute otherwise.
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    let ms = duration.as_millis() as f32;
    match ms {
        ms if ms < MILLIS_PER_SECOND => format!("{ms}ms"),
        ms if ms < MILLIS_PER_MINUTE => format!("{:.1}s", ms / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms / MILLIS_PER_MINUTE),
    }
}