progress.rs

  1use std::{
  2    borrow::Cow,
  3    collections::HashMap,
  4    io::{IsTerminal, Write},
  5    sync::{Arc, Mutex, OnceLock},
  6    time::{Duration, Instant},
  7};
  8
  9use crate::paths::RUN_DIR;
 10
 11use log::{Level, Log, Metadata, Record};
 12
 13pub struct Progress {
 14    inner: Mutex<ProgressInner>,
 15}
 16
 17struct ProgressInner {
 18    completed: Vec<CompletedTask>,
 19    in_progress: HashMap<String, InProgressTask>,
 20    is_tty: bool,
 21    terminal_width: usize,
 22    max_example_name_len: usize,
 23    status_lines_displayed: usize,
 24    total_examples: usize,
 25    completed_examples: usize,
 26    failed_examples: usize,
 27    last_line_is_logging: bool,
 28    ticker: Option<std::thread::JoinHandle<()>>,
 29}
 30
 31#[derive(Clone)]
 32struct InProgressTask {
 33    step: Step,
 34    started_at: Instant,
 35    substatus: Option<String>,
 36    info: Option<(String, InfoStyle)>,
 37}
 38
 39struct CompletedTask {
 40    step: Step,
 41    example_name: String,
 42    duration: Duration,
 43    info: Option<(String, InfoStyle)>,
 44}
 45
 46#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 47pub enum Step {
 48    LoadProject,
 49    Context,
 50    FormatPrompt,
 51    Predict,
 52    Score,
 53    Synthesize,
 54    PullExamples,
 55}
 56
 57#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 58pub enum InfoStyle {
 59    Normal,
 60    Warning,
 61}
 62
 63impl Step {
 64    pub fn label(&self) -> &'static str {
 65        match self {
 66            Step::LoadProject => "Load",
 67            Step::Context => "Context",
 68            Step::FormatPrompt => "Format",
 69            Step::Predict => "Predict",
 70            Step::Score => "Score",
 71            Step::Synthesize => "Synthesize",
 72            Step::PullExamples => "Pull",
 73        }
 74    }
 75
 76    fn color_code(&self) -> &'static str {
 77        match self {
 78            Step::LoadProject => "\x1b[33m",
 79            Step::Context => "\x1b[35m",
 80            Step::FormatPrompt => "\x1b[34m",
 81            Step::Predict => "\x1b[32m",
 82            Step::Score => "\x1b[31m",
 83            Step::Synthesize => "\x1b[36m",
 84            Step::PullExamples => "\x1b[36m",
 85        }
 86    }
 87}
 88
 89static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
 90static LOGGER: ProgressLogger = ProgressLogger;
 91
 92const MARGIN: usize = 4;
 93const MAX_STATUS_LINES: usize = 10;
 94const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
 95
 96impl Progress {
 97    /// Returns the global Progress instance, initializing it if necessary.
 98    pub fn global() -> Arc<Progress> {
 99        GLOBAL
100            .get_or_init(|| {
101                let progress = Arc::new(Self {
102                    inner: Mutex::new(ProgressInner {
103                        completed: Vec::new(),
104                        in_progress: HashMap::new(),
105                        is_tty: std::env::var("COLOR").is_ok()
106                            || (std::env::var("NO_COLOR").is_err()
107                                && std::io::stderr().is_terminal()),
108                        terminal_width: get_terminal_width(),
109                        max_example_name_len: 0,
110                        status_lines_displayed: 0,
111                        total_examples: 0,
112                        completed_examples: 0,
113                        failed_examples: 0,
114                        last_line_is_logging: false,
115                        ticker: None,
116                    }),
117                });
118                let _ = log::set_logger(&LOGGER);
119                log::set_max_level(log::LevelFilter::Info);
120                progress
121            })
122            .clone()
123    }
124
125    pub fn start_group(self: &Arc<Self>, example_name: &str) -> ExampleProgress {
126        ExampleProgress {
127            progress: self.clone(),
128            example_name: example_name.to_string(),
129        }
130    }
131
132    fn increment_completed(&self) {
133        let mut inner = self.inner.lock().unwrap();
134        inner.completed_examples += 1;
135    }
136
137    pub fn set_total_examples(&self, total: usize) {
138        let mut inner = self.inner.lock().unwrap();
139        inner.total_examples = total;
140    }
141
142    pub fn increment_failed(&self) {
143        let mut inner = self.inner.lock().unwrap();
144        inner.failed_examples += 1;
145    }
146
147    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
148    /// This should be used for any output that needs to appear above the status lines.
149    fn log(&self, message: &str) {
150        let mut inner = self.inner.lock().unwrap();
151        Self::clear_status_lines(&mut inner);
152
153        if !inner.last_line_is_logging {
154            let reset = "\x1b[0m";
155            let dim = "\x1b[2m";
156            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
157            eprintln!("{dim}{divider}{reset}");
158            inner.last_line_is_logging = true;
159        }
160
161        let max_width = inner.terminal_width.saturating_sub(MARGIN);
162        for line in message.lines() {
163            let truncated = truncate_to_visible_width(line, max_width);
164            if truncated.len() < line.len() {
165                eprintln!("{}", truncated);
166            } else {
167                eprintln!("{}", truncated);
168            }
169        }
170
171        Self::print_status_lines(&mut inner);
172    }
173
174    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
175        let mut inner = self.inner.lock().unwrap();
176
177        Self::clear_status_lines(&mut inner);
178
179        let max_name_width = inner
180            .terminal_width
181            .saturating_sub(MARGIN * 2)
182            .saturating_div(3)
183            .max(1);
184        inner.max_example_name_len = inner
185            .max_example_name_len
186            .max(example_name.len().min(max_name_width));
187        inner.in_progress.insert(
188            example_name.to_string(),
189            InProgressTask {
190                step,
191                started_at: Instant::now(),
192                substatus: None,
193                info: None,
194            },
195        );
196
197        if inner.is_tty && inner.ticker.is_none() {
198            let progress = self.clone();
199            inner.ticker = Some(std::thread::spawn(move || {
200                loop {
201                    std::thread::sleep(STATUS_TICK_INTERVAL);
202
203                    let mut inner = progress.inner.lock().unwrap();
204                    if inner.in_progress.is_empty() {
205                        break;
206                    }
207
208                    Progress::clear_status_lines(&mut inner);
209                    Progress::print_status_lines(&mut inner);
210                }
211            }));
212        }
213
214        Self::print_status_lines(&mut inner);
215
216        StepProgress {
217            progress: self.clone(),
218            step,
219            example_name: example_name.to_string(),
220        }
221    }
222
223    fn finish(&self, step: Step, example_name: &str) {
224        let mut inner = self.inner.lock().unwrap();
225
226        let Some(task) = inner.in_progress.remove(example_name) else {
227            return;
228        };
229
230        if task.step == step {
231            inner.completed.push(CompletedTask {
232                step: task.step,
233                example_name: example_name.to_string(),
234                duration: task.started_at.elapsed(),
235                info: task.info,
236            });
237
238            Self::clear_status_lines(&mut inner);
239            Self::print_logging_closing_divider(&mut inner);
240            if let Some(last_completed) = inner.completed.last() {
241                Self::print_completed(&inner, last_completed);
242            }
243            Self::print_status_lines(&mut inner);
244        } else {
245            inner.in_progress.insert(example_name.to_string(), task);
246        }
247    }
248
249    fn print_logging_closing_divider(inner: &mut ProgressInner) {
250        if inner.last_line_is_logging {
251            let reset = "\x1b[0m";
252            let dim = "\x1b[2m";
253            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
254            eprintln!("{dim}{divider}{reset}");
255            inner.last_line_is_logging = false;
256        }
257    }
258
259    fn clear_status_lines(inner: &mut ProgressInner) {
260        if inner.is_tty && inner.status_lines_displayed > 0 {
261            // Move up and clear each line we previously displayed
262            for _ in 0..inner.status_lines_displayed {
263                eprint!("\x1b[A\x1b[K");
264            }
265            inner.status_lines_displayed = 0;
266        }
267    }
268
269    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
270        let duration = format_duration(task.duration);
271        let name_width = inner.max_example_name_len;
272        let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
273
274        if inner.is_tty {
275            let reset = "\x1b[0m";
276            let bold = "\x1b[1m";
277            let dim = "\x1b[2m";
278
279            let yellow = "\x1b[33m";
280            let info_part = task
281                .info
282                .as_ref()
283                .map(|(s, style)| {
284                    if *style == InfoStyle::Warning {
285                        format!("{yellow}{s}{reset}")
286                    } else {
287                        s.to_string()
288                    }
289                })
290                .unwrap_or_default();
291
292            let prefix = format!(
293                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
294                color = task.step.color_code(),
295                label = task.step.label(),
296                name = truncated_name,
297            );
298
299            let duration_with_margin = format!("{duration} ");
300            let padding_needed = inner
301                .terminal_width
302                .saturating_sub(MARGIN)
303                .saturating_sub(duration_with_margin.len())
304                .saturating_sub(strip_ansi_len(&prefix));
305            let padding = " ".repeat(padding_needed);
306
307            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
308        } else {
309            let info_part = task
310                .info
311                .as_ref()
312                .map(|(s, _)| format!(" | {}", s))
313                .unwrap_or_default();
314
315            eprintln!(
316                "{label:>12} {name:<name_width$}{info_part} {duration}",
317                label = task.step.label(),
318                name = truncate_with_ellipsis(&task.example_name, name_width),
319            );
320        }
321    }
322
323    fn print_status_lines(inner: &mut ProgressInner) {
324        if !inner.is_tty || inner.in_progress.is_empty() {
325            inner.status_lines_displayed = 0;
326            return;
327        }
328
329        let reset = "\x1b[0m";
330        let bold = "\x1b[1m";
331        let dim = "\x1b[2m";
332
333        // Build the done/in-progress/total label
334        let done_count = inner.completed_examples;
335        let in_progress_count = inner.in_progress.len();
336        let failed_count = inner.failed_examples;
337
338        let failed_label = if failed_count > 0 {
339            format!(" {} failed ", failed_count)
340        } else {
341            String::new()
342        };
343
344        let range_label = format!(
345            " {}/{}/{} ",
346            done_count, in_progress_count, inner.total_examples
347        );
348
349        // Print a divider line with failed count on left, range label on right
350        let failed_visible_len = strip_ansi_len(&failed_label);
351        let range_visible_len = range_label.len();
352        let middle_divider_len = inner
353            .terminal_width
354            .saturating_sub(MARGIN * 2)
355            .saturating_sub(failed_visible_len)
356            .saturating_sub(range_visible_len);
357        let left_divider = "".repeat(MARGIN);
358        let middle_divider = "".repeat(middle_divider_len);
359        let right_divider = "".repeat(MARGIN);
360        eprintln!(
361            "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
362        );
363
364        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
365        tasks.sort_by_key(|(name, _)| *name);
366
367        let total_tasks = tasks.len();
368        let mut lines_printed = 0;
369
370        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
371            let elapsed = format_duration(task.started_at.elapsed());
372            let substatus_part = task
373                .substatus
374                .as_ref()
375                .map(|s| truncate_with_ellipsis(s, 30))
376                .unwrap_or_default();
377
378            let step_label = task.step.label();
379            let step_color = task.step.color_code();
380            let name_width = inner.max_example_name_len;
381            let truncated_name = truncate_with_ellipsis(name, name_width);
382
383            let prefix = format!(
384                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
385                name = truncated_name,
386            );
387
388            let duration_with_margin = format!("{elapsed} ");
389            let padding_needed = inner
390                .terminal_width
391                .saturating_sub(MARGIN)
392                .saturating_sub(duration_with_margin.len())
393                .saturating_sub(strip_ansi_len(&prefix));
394            let padding = " ".repeat(padding_needed);
395
396            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
397            lines_printed += 1;
398        }
399
400        // Show "+N more" on its own line if there are more tasks
401        if total_tasks > MAX_STATUS_LINES {
402            let remaining = total_tasks - MAX_STATUS_LINES;
403            eprintln!("{:>12} +{remaining} more", "");
404            lines_printed += 1;
405        }
406
407        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
408        let _ = std::io::stderr().flush();
409    }
410
411    pub fn finalize(&self) {
412        let ticker = {
413            let mut inner = self.inner.lock().unwrap();
414            inner.ticker.take()
415        };
416
417        if let Some(ticker) = ticker {
418            let _ = ticker.join();
419        }
420
421        let mut inner = self.inner.lock().unwrap();
422        Self::clear_status_lines(&mut inner);
423
424        // Print summary if there were failures
425        if inner.failed_examples > 0 {
426            let total_examples = inner.total_examples;
427            let percentage = if total_examples > 0 {
428                inner.failed_examples as f64 / total_examples as f64 * 100.0
429            } else {
430                0.0
431            };
432            let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
433            eprintln!(
434                "\n{} of {} examples failed ({:.1}%)\nFailed examples: {}",
435                inner.failed_examples,
436                total_examples,
437                percentage,
438                failed_jsonl_path.display()
439            );
440        }
441    }
442}
443
444pub struct ExampleProgress {
445    progress: Arc<Progress>,
446    example_name: String,
447}
448
449impl ExampleProgress {
450    pub fn start(&self, step: Step) -> StepProgress {
451        self.progress.start(step, &self.example_name)
452    }
453}
454
455impl Drop for ExampleProgress {
456    fn drop(&mut self) {
457        self.progress.increment_completed();
458    }
459}
460
461pub struct StepProgress {
462    progress: Arc<Progress>,
463    step: Step,
464    example_name: String,
465}
466
467impl StepProgress {
468    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
469        let mut inner = self.progress.inner.lock().unwrap();
470        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
471            task.substatus = Some(substatus.into().into_owned());
472            Progress::clear_status_lines(&mut inner);
473            Progress::print_status_lines(&mut inner);
474        }
475    }
476
477    pub fn clear_substatus(&self) {
478        let mut inner = self.progress.inner.lock().unwrap();
479        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
480            task.substatus = None;
481            Progress::clear_status_lines(&mut inner);
482            Progress::print_status_lines(&mut inner);
483        }
484    }
485
486    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
487        let mut inner = self.progress.inner.lock().unwrap();
488        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
489            task.info = Some((info.into(), style));
490        }
491    }
492}
493
494impl Drop for StepProgress {
495    fn drop(&mut self) {
496        self.progress.finish(self.step, &self.example_name);
497    }
498}
499
500struct ProgressLogger;
501
502impl Log for ProgressLogger {
503    fn enabled(&self, metadata: &Metadata) -> bool {
504        metadata.level() <= Level::Info
505    }
506
507    fn log(&self, record: &Record) {
508        if !self.enabled(record.metadata()) {
509            return;
510        }
511
512        let level_color = match record.level() {
513            Level::Error => "\x1b[31m",
514            Level::Warn => "\x1b[33m",
515            Level::Info => "\x1b[32m",
516            Level::Debug => "\x1b[34m",
517            Level::Trace => "\x1b[35m",
518        };
519        let reset = "\x1b[0m";
520        let bold = "\x1b[1m";
521
522        let level_label = match record.level() {
523            Level::Error => "Error",
524            Level::Warn => "Warn",
525            Level::Info => "Info",
526            Level::Debug => "Debug",
527            Level::Trace => "Trace",
528        };
529
530        let message = format!(
531            "{bold}{level_color}{level_label:>12}{reset} {}",
532            record.args()
533        );
534
535        if let Some(progress) = GLOBAL.get() {
536            progress.log(&message);
537        } else {
538            eprintln!("{}", message);
539        }
540    }
541
542    fn flush(&self) {
543        let _ = std::io::stderr().flush();
544    }
545}
546
547#[cfg(unix)]
548fn get_terminal_width() -> usize {
549    unsafe {
550        let mut winsize: libc::winsize = std::mem::zeroed();
551        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
552            && winsize.ws_col > 0
553        {
554            winsize.ws_col as usize
555        } else {
556            80
557        }
558    }
559}
560
561#[cfg(not(unix))]
562fn get_terminal_width() -> usize {
563    80
564}
565
/// Counts the visible chars of `s`, skipping ANSI escape sequences.
///
/// An escape runs from `ESC` up to and including the next `m`; every other char counts as
/// one column.
fn strip_ansi_len(s: &str) -> usize {
    s.chars()
        .scan(false, |in_escape, c| {
            let visible = if *in_escape {
                // Another ESC keeps us inside an escape; only `m` ends it.
                *in_escape = c != 'm';
                false
            } else if c == '\x1b' {
                *in_escape = true;
                false
            } else {
                true
            };
            Some(visible)
        })
        .filter(|&visible| visible)
        .count()
}
582
/// Shortens `s` to at most `max_len` chars, ending it with `…` when anything was cut.
///
/// Lengths are counted in chars (not bytes), which matches how `{:<width$}` pads. Slicing
/// always happens on char boundaries, so multi-byte text never panics. Returns `s` borrowed
/// when it already fits.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> Cow<'_, str> {
    if s.chars().nth(max_len).is_none() {
        return Cow::Borrowed(s);
    }
    if max_len == 0 {
        // No room even for the ellipsis.
        return Cow::Borrowed("");
    }
    // Keep `max_len - 1` chars and spend the final column on the ellipsis.
    let keep_bytes = s
        .char_indices()
        .nth(max_len - 1)
        .map_or(s.len(), |(index, _)| index);
    Cow::Owned(format!("{}…", &s[..keep_bytes]))
}
590
/// Returns the longest prefix of `s` that has at most `max_visible_len` visible chars.
///
/// ANSI escape sequences (`ESC` through the next `m`) take no width. Escapes that come before
/// the first char past the limit are kept. When the whole of `s` fits, `s` is returned
/// unchanged.
fn truncate_to_visible_width(s: &str, max_visible_len: usize) -> &str {
    let mut in_escape = false;
    let mut visible = 0usize;
    for (idx, ch) in s.char_indices() {
        match (in_escape, ch) {
            (_, '\x1b') => in_escape = true,
            (true, 'm') => in_escape = false,
            (true, _) => {}
            (false, _) if visible >= max_visible_len => return &s[..idx],
            (false, _) => visible += 1,
        }
    }
    s
}
612
/// Formats `duration` compactly.
///
/// Whole milliseconds are shown below one second, seconds with one decimal below one minute,
/// and minutes with one decimal beyond that.
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    let millis = duration.as_millis() as f32;
    match millis {
        m if m < MILLIS_PER_SECOND => format!("{m}ms"),
        m if m < MILLIS_PER_MINUTE => format!("{:.1}s", m / MILLIS_PER_SECOND),
        m => format!("{:.1}m", m / MILLIS_PER_MINUTE),
    }
}