progress.rs

  1use std::{
  2    borrow::Cow,
  3    collections::HashMap,
  4    io::{IsTerminal, Write},
  5    sync::{Arc, Mutex, OnceLock},
  6    time::{Duration, Instant},
  7};
  8
  9use log::{Level, Log, Metadata, Record};
 10
 11pub struct Progress {
 12    inner: Mutex<ProgressInner>,
 13}
 14
 15struct ProgressInner {
 16    completed: Vec<CompletedTask>,
 17    in_progress: HashMap<String, InProgressTask>,
 18    is_tty: bool,
 19    terminal_width: usize,
 20    max_example_name_len: usize,
 21    status_lines_displayed: usize,
 22    total_examples: usize,
 23    last_line_is_logging: bool,
 24}
 25
 26#[derive(Clone)]
 27struct InProgressTask {
 28    step: Step,
 29    started_at: Instant,
 30    substatus: Option<String>,
 31    info: Option<(String, InfoStyle)>,
 32}
 33
 34struct CompletedTask {
 35    step: Step,
 36    example_name: String,
 37    duration: Duration,
 38    info: Option<(String, InfoStyle)>,
 39}
 40
 41#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 42pub enum Step {
 43    LoadProject,
 44    Context,
 45    FormatPrompt,
 46    Predict,
 47    Score,
 48}
 49
 50#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 51pub enum InfoStyle {
 52    Normal,
 53    Warning,
 54}
 55
 56impl Step {
 57    pub fn label(&self) -> &'static str {
 58        match self {
 59            Step::LoadProject => "Load",
 60            Step::Context => "Context",
 61            Step::FormatPrompt => "Format",
 62            Step::Predict => "Predict",
 63            Step::Score => "Score",
 64        }
 65    }
 66
 67    fn color_code(&self) -> &'static str {
 68        match self {
 69            Step::LoadProject => "\x1b[33m",
 70            Step::Context => "\x1b[35m",
 71            Step::FormatPrompt => "\x1b[34m",
 72            Step::Predict => "\x1b[32m",
 73            Step::Score => "\x1b[31m",
 74        }
 75    }
 76}
 77
 78static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
 79static LOGGER: ProgressLogger = ProgressLogger;
 80
 81const RIGHT_MARGIN: usize = 4;
 82const MAX_STATUS_LINES: usize = 10;
 83
 84impl Progress {
 85    /// Returns the global Progress instance, initializing it if necessary.
 86    pub fn global() -> Arc<Progress> {
 87        GLOBAL
 88            .get_or_init(|| {
 89                let progress = Arc::new(Self {
 90                    inner: Mutex::new(ProgressInner {
 91                        completed: Vec::new(),
 92                        in_progress: HashMap::new(),
 93                        is_tty: std::io::stderr().is_terminal(),
 94                        terminal_width: get_terminal_width(),
 95                        max_example_name_len: 0,
 96                        status_lines_displayed: 0,
 97                        total_examples: 0,
 98                        last_line_is_logging: false,
 99                    }),
100                });
101                let _ = log::set_logger(&LOGGER);
102                log::set_max_level(log::LevelFilter::Error);
103                progress
104            })
105            .clone()
106    }
107
108    pub fn set_total_examples(&self, total: usize) {
109        let mut inner = self.inner.lock().unwrap();
110        inner.total_examples = total;
111    }
112
113    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
114    /// This should be used for any output that needs to appear above the status lines.
115    fn log(&self, message: &str) {
116        let mut inner = self.inner.lock().unwrap();
117        Self::clear_status_lines(&mut inner);
118
119        if !inner.last_line_is_logging {
120            let reset = "\x1b[0m";
121            let dim = "\x1b[2m";
122            let divider = "".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
123            eprintln!("{dim}{divider}{reset}");
124            inner.last_line_is_logging = true;
125        }
126
127        eprintln!("{}", message);
128    }
129
130    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
131        let mut inner = self.inner.lock().unwrap();
132
133        Self::clear_status_lines(&mut inner);
134
135        inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
136        inner.in_progress.insert(
137            example_name.to_string(),
138            InProgressTask {
139                step,
140                started_at: Instant::now(),
141                substatus: None,
142                info: None,
143            },
144        );
145
146        Self::print_status_lines(&mut inner);
147
148        StepProgress {
149            progress: self.clone(),
150            step,
151            example_name: example_name.to_string(),
152        }
153    }
154
155    fn finish(&self, step: Step, example_name: &str) {
156        let mut inner = self.inner.lock().unwrap();
157
158        let Some(task) = inner.in_progress.remove(example_name) else {
159            return;
160        };
161
162        if task.step == step {
163            inner.completed.push(CompletedTask {
164                step: task.step,
165                example_name: example_name.to_string(),
166                duration: task.started_at.elapsed(),
167                info: task.info,
168            });
169
170            Self::clear_status_lines(&mut inner);
171            Self::print_logging_closing_divider(&mut inner);
172            Self::print_completed(&inner, inner.completed.last().unwrap());
173            Self::print_status_lines(&mut inner);
174        } else {
175            inner.in_progress.insert(example_name.to_string(), task);
176        }
177    }
178
179    fn print_logging_closing_divider(inner: &mut ProgressInner) {
180        if inner.last_line_is_logging {
181            let reset = "\x1b[0m";
182            let dim = "\x1b[2m";
183            let divider = "".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
184            eprintln!("{dim}{divider}{reset}");
185            inner.last_line_is_logging = false;
186        }
187    }
188
189    fn clear_status_lines(inner: &mut ProgressInner) {
190        if inner.is_tty && inner.status_lines_displayed > 0 {
191            // Move up and clear each line we previously displayed
192            for _ in 0..inner.status_lines_displayed {
193                eprint!("\x1b[A\x1b[K");
194            }
195            let _ = std::io::stderr().flush();
196            inner.status_lines_displayed = 0;
197        }
198    }
199
200    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
201        let duration = format_duration(task.duration);
202        let name_width = inner.max_example_name_len;
203
204        if inner.is_tty {
205            let reset = "\x1b[0m";
206            let bold = "\x1b[1m";
207            let dim = "\x1b[2m";
208
209            let yellow = "\x1b[33m";
210            let info_part = task
211                .info
212                .as_ref()
213                .map(|(s, style)| {
214                    if *style == InfoStyle::Warning {
215                        format!("{yellow}{s}{reset}")
216                    } else {
217                        s.to_string()
218                    }
219                })
220                .unwrap_or_default();
221
222            let prefix = format!(
223                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
224                color = task.step.color_code(),
225                label = task.step.label(),
226                name = task.example_name,
227            );
228
229            let duration_with_margin = format!("{duration} ");
230            let padding_needed = inner
231                .terminal_width
232                .saturating_sub(RIGHT_MARGIN)
233                .saturating_sub(duration_with_margin.len())
234                .saturating_sub(strip_ansi_len(&prefix));
235            let padding = " ".repeat(padding_needed);
236
237            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
238        } else {
239            let info_part = task
240                .info
241                .as_ref()
242                .map(|(s, _)| format!(" | {}", s))
243                .unwrap_or_default();
244
245            eprintln!(
246                "{label:>12} {name:<name_width$}{info_part} {duration}",
247                label = task.step.label(),
248                name = task.example_name,
249            );
250        }
251    }
252
253    fn print_status_lines(inner: &mut ProgressInner) {
254        if !inner.is_tty || inner.in_progress.is_empty() {
255            inner.status_lines_displayed = 0;
256            return;
257        }
258
259        let reset = "\x1b[0m";
260        let bold = "\x1b[1m";
261        let dim = "\x1b[2m";
262
263        // Build the done/in-progress/total label
264        let done_count = inner.completed.len();
265        let in_progress_count = inner.in_progress.len();
266        let range_label = format!(
267            " {}/{}/{} ",
268            done_count, in_progress_count, inner.total_examples
269        );
270
271        // Print a divider line with range label aligned with timestamps
272        let range_visible_len = range_label.len();
273        let left_divider_len = inner
274            .terminal_width
275            .saturating_sub(RIGHT_MARGIN)
276            .saturating_sub(range_visible_len);
277        let left_divider = "".repeat(left_divider_len);
278        let right_divider = "".repeat(RIGHT_MARGIN);
279        eprintln!("{dim}{left_divider}{reset}{range_label}{dim}{right_divider}{reset}");
280
281        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
282        tasks.sort_by_key(|(name, _)| *name);
283
284        let total_tasks = tasks.len();
285        let mut lines_printed = 0;
286
287        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
288            let elapsed = format_duration(task.started_at.elapsed());
289            let substatus_part = task
290                .substatus
291                .as_ref()
292                .map(|s| truncate_with_ellipsis(s, 30))
293                .unwrap_or_default();
294
295            let step_label = task.step.label();
296            let step_color = task.step.color_code();
297            let name_width = inner.max_example_name_len;
298
299            let prefix = format!(
300                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
301                name = name,
302            );
303
304            let duration_with_margin = format!("{elapsed} ");
305            let padding_needed = inner
306                .terminal_width
307                .saturating_sub(RIGHT_MARGIN)
308                .saturating_sub(duration_with_margin.len())
309                .saturating_sub(strip_ansi_len(&prefix));
310            let padding = " ".repeat(padding_needed);
311
312            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
313            lines_printed += 1;
314        }
315
316        // Show "+N more" on its own line if there are more tasks
317        if total_tasks > MAX_STATUS_LINES {
318            let remaining = total_tasks - MAX_STATUS_LINES;
319            eprintln!("{:>12} +{remaining} more", "");
320            lines_printed += 1;
321        }
322
323        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
324        let _ = std::io::stderr().flush();
325    }
326
327    pub fn clear(&self) {
328        let mut inner = self.inner.lock().unwrap();
329        Self::clear_status_lines(&mut inner);
330    }
331}
332
333pub struct StepProgress {
334    progress: Arc<Progress>,
335    step: Step,
336    example_name: String,
337}
338
339impl StepProgress {
340    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
341        let mut inner = self.progress.inner.lock().unwrap();
342        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
343            task.substatus = Some(substatus.into().into_owned());
344            Progress::clear_status_lines(&mut inner);
345            Progress::print_status_lines(&mut inner);
346        }
347    }
348
349    pub fn clear_substatus(&self) {
350        let mut inner = self.progress.inner.lock().unwrap();
351        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
352            task.substatus = None;
353            Progress::clear_status_lines(&mut inner);
354            Progress::print_status_lines(&mut inner);
355        }
356    }
357
358    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
359        let mut inner = self.progress.inner.lock().unwrap();
360        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
361            task.info = Some((info.into(), style));
362        }
363    }
364}
365
366impl Drop for StepProgress {
367    fn drop(&mut self) {
368        self.progress.finish(self.step, &self.example_name);
369    }
370}
371
372struct ProgressLogger;
373
374impl Log for ProgressLogger {
375    fn enabled(&self, metadata: &Metadata) -> bool {
376        metadata.level() <= Level::Info
377    }
378
379    fn log(&self, record: &Record) {
380        if !self.enabled(record.metadata()) {
381            return;
382        }
383
384        let level_color = match record.level() {
385            Level::Error => "\x1b[31m",
386            Level::Warn => "\x1b[33m",
387            Level::Info => "\x1b[32m",
388            Level::Debug => "\x1b[34m",
389            Level::Trace => "\x1b[35m",
390        };
391        let reset = "\x1b[0m";
392        let bold = "\x1b[1m";
393
394        let level_label = match record.level() {
395            Level::Error => "Error",
396            Level::Warn => "Warn",
397            Level::Info => "Info",
398            Level::Debug => "Debug",
399            Level::Trace => "Trace",
400        };
401
402        let message = format!(
403            "{bold}{level_color}{level_label:>12}{reset} {}",
404            record.args()
405        );
406
407        if let Some(progress) = GLOBAL.get() {
408            progress.log(&message);
409        } else {
410            eprintln!("{}", message);
411        }
412    }
413
414    fn flush(&self) {
415        let _ = std::io::stderr().flush();
416    }
417}
418
419#[cfg(unix)]
420fn get_terminal_width() -> usize {
421    unsafe {
422        let mut winsize: libc::winsize = std::mem::zeroed();
423        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
424            && winsize.ws_col > 0
425        {
426            winsize.ws_col as usize
427        } else {
428            80
429        }
430    }
431}
432
/// Returns the assumed terminal width on non-Unix targets: always 80 columns.
#[cfg(not(unix))]
fn get_terminal_width() -> usize {
    const DEFAULT_COLUMNS: usize = 80;
    DEFAULT_COLUMNS
}
437
/// Counts the characters of `s` that lie outside escape sequences.
///
/// A sequence starts at an ESC (`\x1b`) character and runs up to and including
/// the next `m`; everything in between is excluded from the count.
fn strip_ansi_len(s: &str) -> usize {
    let mut visible = 0;
    let mut inside_sequence = false;

    for ch in s.chars() {
        match (inside_sequence, ch) {
            // ESC always (re)starts a sequence, even inside one.
            (_, '\x1b') => inside_sequence = true,
            (true, 'm') => inside_sequence = false,
            (true, _) => {}
            (false, _) => visible += 1,
        }
    }

    visible
}
454
/// Shortens `s` to at most `max_len` characters, marking the cut with `…`.
///
/// Strings that already fit are returned unchanged. Otherwise the first
/// `max_len - 1` characters are kept and an ellipsis is appended, so the
/// result is exactly `max_len` characters long. A `max_len` of zero yields an
/// empty string.
///
/// Lengths are measured in characters rather than bytes, and the cut is
/// always made on a character boundary, so multi-byte text never panics.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    if max_len == 0 {
        return String::new();
    }

    // No character at index `max_len` means the string has at most `max_len` chars.
    if s.chars().nth(max_len).is_none() {
        return s.to_string();
    }

    // Byte offset where the character at index `max_len - 1` starts.
    let cut = s
        .char_indices()
        .nth(max_len - 1)
        .map_or(s.len(), |(offset, _)| offset);

    format!("{}…", &s[..cut])
}
462
/// Renders a duration compactly for display.
///
/// Whole milliseconds are used below one second (`"250ms"`), seconds with one
/// decimal below one minute (`"1.5s"`), and minutes with one decimal from
/// there on (`"2.5m"`).
fn format_duration(duration: Duration) -> String {
    const SECOND_IN_MILLIS: f32 = 1000.0;
    const MINUTE_IN_MILLIS: f32 = 60. * 1000.;

    match duration.as_millis() as f32 {
        ms if ms < SECOND_IN_MILLIS => format!("{}ms", ms),
        ms if ms < MINUTE_IN_MILLIS => format!("{:.1}s", ms / 1_000.0),
        ms => format!("{:.1}m", ms / MINUTE_IN_MILLIS),
    }
}