progress.rs

  1use std::{
  2    borrow::Cow,
  3    collections::HashMap,
  4    io::{IsTerminal, Write},
  5    sync::{Arc, Mutex, OnceLock},
  6    time::{Duration, Instant},
  7};
  8
  9use crate::paths::RUN_DIR;
 10
 11use log::{Level, Log, Metadata, Record};
 12
 13pub struct Progress {
 14    inner: Mutex<ProgressInner>,
 15}
 16
 17struct ProgressInner {
 18    completed: Vec<CompletedTask>,
 19    in_progress: HashMap<String, InProgressTask>,
 20    is_tty: bool,
 21    terminal_width: usize,
 22    max_example_name_len: usize,
 23    status_lines_displayed: usize,
 24    total_examples: usize,
 25    completed_examples: usize,
 26    failed_examples: usize,
 27    last_line_is_logging: bool,
 28    ticker: Option<std::thread::JoinHandle<()>>,
 29}
 30
 31#[derive(Clone)]
 32struct InProgressTask {
 33    step: Step,
 34    started_at: Instant,
 35    substatus: Option<String>,
 36    info: Option<(String, InfoStyle)>,
 37}
 38
 39struct CompletedTask {
 40    step: Step,
 41    example_name: String,
 42    duration: Duration,
 43    info: Option<(String, InfoStyle)>,
 44}
 45
 46#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 47pub enum Step {
 48    LoadProject,
 49    Context,
 50    FormatPrompt,
 51    Predict,
 52    Score,
 53    Qa,
 54    Repair,
 55    Synthesize,
 56    PullExamples,
 57}
 58
 59#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 60pub enum InfoStyle {
 61    Normal,
 62    Warning,
 63}
 64
 65impl Step {
 66    pub fn label(&self) -> &'static str {
 67        match self {
 68            Step::LoadProject => "Load",
 69            Step::Context => "Context",
 70            Step::FormatPrompt => "Format",
 71            Step::Predict => "Predict",
 72            Step::Score => "Score",
 73            Step::Qa => "QA",
 74            Step::Repair => "Repair",
 75            Step::Synthesize => "Synthesize",
 76            Step::PullExamples => "Pull",
 77        }
 78    }
 79
 80    fn color_code(&self) -> &'static str {
 81        match self {
 82            Step::LoadProject => "\x1b[33m",
 83            Step::Context => "\x1b[35m",
 84            Step::FormatPrompt => "\x1b[34m",
 85            Step::Predict => "\x1b[32m",
 86            Step::Score => "\x1b[31m",
 87            Step::Qa => "\x1b[36m",
 88            Step::Repair => "\x1b[95m",
 89            Step::Synthesize => "\x1b[36m",
 90            Step::PullExamples => "\x1b[36m",
 91        }
 92    }
 93}
 94
 95static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
 96static LOGGER: ProgressLogger = ProgressLogger;
 97
 98const MARGIN: usize = 4;
 99const MAX_STATUS_LINES: usize = 10;
100const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
101
102impl Progress {
103    /// Returns the global Progress instance, initializing it if necessary.
104    pub fn global() -> Arc<Progress> {
105        GLOBAL
106            .get_or_init(|| {
107                let progress = Arc::new(Self {
108                    inner: Mutex::new(ProgressInner {
109                        completed: Vec::new(),
110                        in_progress: HashMap::new(),
111                        is_tty: std::env::var("COLOR").is_ok()
112                            || (std::env::var("NO_COLOR").is_err()
113                                && std::io::stderr().is_terminal()),
114                        terminal_width: get_terminal_width(),
115                        max_example_name_len: 0,
116                        status_lines_displayed: 0,
117                        total_examples: 0,
118                        completed_examples: 0,
119                        failed_examples: 0,
120                        last_line_is_logging: false,
121                        ticker: None,
122                    }),
123                });
124                let _ = log::set_logger(&LOGGER);
125                log::set_max_level(log::LevelFilter::Info);
126                progress
127            })
128            .clone()
129    }
130
131    pub fn start_group(self: &Arc<Self>, example_name: &str) -> ExampleProgress {
132        ExampleProgress {
133            progress: self.clone(),
134            example_name: example_name.to_string(),
135        }
136    }
137
138    fn increment_completed(&self) {
139        let mut inner = self.inner.lock().unwrap();
140        inner.completed_examples += 1;
141    }
142
143    pub fn set_total_examples(&self, total: usize) {
144        let mut inner = self.inner.lock().unwrap();
145        inner.total_examples = total;
146    }
147
148    pub fn set_max_example_name_len(&self, example_names: impl Iterator<Item = impl AsRef<str>>) {
149        let mut inner = self.inner.lock().unwrap();
150        let max_name_width = inner
151            .terminal_width
152            .saturating_sub(MARGIN * 2)
153            .saturating_div(3)
154            .max(1);
155        inner.max_example_name_len = example_names
156            .map(|name| name.as_ref().len().min(max_name_width))
157            .max()
158            .unwrap_or(0);
159    }
160
161    pub fn increment_failed(&self) {
162        let mut inner = self.inner.lock().unwrap();
163        inner.failed_examples += 1;
164    }
165
166    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
167    /// This should be used for any output that needs to appear above the status lines.
168    fn log(&self, message: &str) {
169        let mut inner = self.inner.lock().unwrap();
170        Self::clear_status_lines(&mut inner);
171
172        if !inner.last_line_is_logging {
173            let reset = "\x1b[0m";
174            let dim = "\x1b[2m";
175            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
176            eprintln!("{dim}{divider}{reset}");
177            inner.last_line_is_logging = true;
178        }
179
180        let max_width = inner.terminal_width.saturating_sub(MARGIN);
181        for line in message.lines() {
182            let truncated = truncate_to_visible_width(line, max_width);
183            if truncated.len() < line.len() {
184                eprintln!("{}", truncated);
185            } else {
186                eprintln!("{}", truncated);
187            }
188        }
189
190        Self::print_status_lines(&mut inner);
191    }
192
193    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
194        let mut inner = self.inner.lock().unwrap();
195
196        Self::clear_status_lines(&mut inner);
197
198        // Update max_example_name_len if not already set via set_max_example_name_len
199        if inner.max_example_name_len == 0 {
200            let max_name_width = inner
201                .terminal_width
202                .saturating_sub(MARGIN * 2)
203                .saturating_div(3)
204                .max(1);
205            inner.max_example_name_len = example_name.len().min(max_name_width);
206        }
207        inner.in_progress.insert(
208            example_name.to_string(),
209            InProgressTask {
210                step,
211                started_at: Instant::now(),
212                substatus: None,
213                info: None,
214            },
215        );
216
217        if inner.is_tty && inner.ticker.is_none() {
218            let progress = self.clone();
219            inner.ticker = Some(std::thread::spawn(move || {
220                loop {
221                    std::thread::sleep(STATUS_TICK_INTERVAL);
222
223                    let mut inner = progress.inner.lock().unwrap();
224                    if inner.in_progress.is_empty() {
225                        break;
226                    }
227
228                    Progress::clear_status_lines(&mut inner);
229                    Progress::print_status_lines(&mut inner);
230                }
231            }));
232        }
233
234        Self::print_status_lines(&mut inner);
235
236        StepProgress {
237            progress: self.clone(),
238            step,
239            example_name: example_name.to_string(),
240        }
241    }
242
243    fn finish(&self, step: Step, example_name: &str) {
244        let mut inner = self.inner.lock().unwrap();
245
246        let Some(task) = inner.in_progress.remove(example_name) else {
247            return;
248        };
249
250        if task.step == step {
251            let duration = task.started_at.elapsed();
252
253            // Skip logging for tasks that complete quickly (under 500ms)
254            let should_print = duration >= Duration::from_millis(500);
255
256            inner.completed.push(CompletedTask {
257                step: task.step,
258                example_name: example_name.to_string(),
259                duration,
260                info: task.info,
261            });
262
263            Self::clear_status_lines(&mut inner);
264            if should_print {
265                Self::print_logging_closing_divider(&mut inner);
266                if let Some(last_completed) = inner.completed.last() {
267                    Self::print_completed(&inner, last_completed);
268                }
269            }
270            Self::print_status_lines(&mut inner);
271        } else {
272            inner.in_progress.insert(example_name.to_string(), task);
273        }
274    }
275
276    fn print_logging_closing_divider(inner: &mut ProgressInner) {
277        if inner.last_line_is_logging {
278            let reset = "\x1b[0m";
279            let dim = "\x1b[2m";
280            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
281            eprintln!("{dim}{divider}{reset}");
282            inner.last_line_is_logging = false;
283        }
284    }
285
286    fn clear_status_lines(inner: &mut ProgressInner) {
287        if inner.is_tty && inner.status_lines_displayed > 0 {
288            // Move up and clear each line we previously displayed
289            for _ in 0..inner.status_lines_displayed {
290                eprint!("\x1b[A\x1b[K");
291            }
292            inner.status_lines_displayed = 0;
293        }
294    }
295
296    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
297        let duration = format_duration(task.duration);
298        let name_width = inner.max_example_name_len;
299        let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
300
301        if inner.is_tty {
302            let reset = "\x1b[0m";
303            let bold = "\x1b[1m";
304            let dim = "\x1b[2m";
305
306            let yellow = "\x1b[33m";
307            let info_part = task
308                .info
309                .as_ref()
310                .map(|(s, style)| {
311                    if *style == InfoStyle::Warning {
312                        format!("{yellow}{s}{reset}")
313                    } else {
314                        s.to_string()
315                    }
316                })
317                .unwrap_or_default();
318
319            let prefix = format!(
320                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
321                color = task.step.color_code(),
322                label = task.step.label(),
323                name = truncated_name,
324            );
325
326            let duration_with_margin = format!("{duration} ");
327            let padding_needed = inner
328                .terminal_width
329                .saturating_sub(MARGIN)
330                .saturating_sub(duration_with_margin.len())
331                .saturating_sub(strip_ansi_len(&prefix));
332            let padding = " ".repeat(padding_needed);
333
334            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
335        } else {
336            let info_part = task
337                .info
338                .as_ref()
339                .map(|(s, _)| format!(" | {}", s))
340                .unwrap_or_default();
341
342            eprintln!(
343                "{label:>12} {name:<name_width$}{info_part} {duration}",
344                label = task.step.label(),
345                name = truncate_with_ellipsis(&task.example_name, name_width),
346            );
347        }
348    }
349
350    fn print_status_lines(inner: &mut ProgressInner) {
351        if !inner.is_tty || inner.in_progress.is_empty() {
352            inner.status_lines_displayed = 0;
353            return;
354        }
355
356        let reset = "\x1b[0m";
357        let bold = "\x1b[1m";
358        let dim = "\x1b[2m";
359
360        // Build the done/in-progress/total label
361        let done_count = inner.completed_examples;
362        let in_progress_count = inner.in_progress.len();
363        let failed_count = inner.failed_examples;
364
365        let failed_label = if failed_count > 0 {
366            format!(" {} failed ", failed_count)
367        } else {
368            String::new()
369        };
370
371        let range_label = format!(
372            " {}/{}/{} ",
373            done_count, in_progress_count, inner.total_examples
374        );
375
376        // Print a divider line with failed count on left, range label on right
377        let failed_visible_len = strip_ansi_len(&failed_label);
378        let range_visible_len = range_label.len();
379        let middle_divider_len = inner
380            .terminal_width
381            .saturating_sub(MARGIN * 2)
382            .saturating_sub(failed_visible_len)
383            .saturating_sub(range_visible_len);
384        let left_divider = "".repeat(MARGIN);
385        let middle_divider = "".repeat(middle_divider_len);
386        let right_divider = "".repeat(MARGIN);
387        eprintln!(
388            "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
389        );
390
391        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
392        tasks.sort_by_key(|(name, _)| *name);
393
394        let total_tasks = tasks.len();
395        let mut lines_printed = 0;
396
397        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
398            let elapsed = format_duration(task.started_at.elapsed());
399            let substatus_part = task
400                .substatus
401                .as_ref()
402                .map(|s| truncate_with_ellipsis(s, 30))
403                .unwrap_or_default();
404
405            let step_label = task.step.label();
406            let step_color = task.step.color_code();
407            let name_width = inner.max_example_name_len;
408            let truncated_name = truncate_with_ellipsis(name, name_width);
409
410            let prefix = format!(
411                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
412                name = truncated_name,
413            );
414
415            let duration_with_margin = format!("{elapsed} ");
416            let padding_needed = inner
417                .terminal_width
418                .saturating_sub(MARGIN)
419                .saturating_sub(duration_with_margin.len())
420                .saturating_sub(strip_ansi_len(&prefix));
421            let padding = " ".repeat(padding_needed);
422
423            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
424            lines_printed += 1;
425        }
426
427        // Show "+N more" on its own line if there are more tasks
428        if total_tasks > MAX_STATUS_LINES {
429            let remaining = total_tasks - MAX_STATUS_LINES;
430            eprintln!("{:>12} +{remaining} more", "");
431            lines_printed += 1;
432        }
433
434        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
435        let _ = std::io::stderr().flush();
436    }
437
438    pub fn finalize(&self) {
439        let ticker = {
440            let mut inner = self.inner.lock().unwrap();
441            inner.ticker.take()
442        };
443
444        if let Some(ticker) = ticker {
445            let _ = ticker.join();
446        }
447
448        let mut inner = self.inner.lock().unwrap();
449        Self::clear_status_lines(&mut inner);
450
451        // Print summary if there were failures
452        if inner.failed_examples > 0 {
453            let total_examples = inner.total_examples;
454            let percentage = if total_examples > 0 {
455                inner.failed_examples as f64 / total_examples as f64 * 100.0
456            } else {
457                0.0
458            };
459            let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
460            eprintln!(
461                "\n{} of {} examples failed ({:.1}%)\nFailed examples: {}",
462                inner.failed_examples,
463                total_examples,
464                percentage,
465                failed_jsonl_path.display()
466            );
467        }
468    }
469}
470
471pub struct ExampleProgress {
472    progress: Arc<Progress>,
473    example_name: String,
474}
475
476impl ExampleProgress {
477    pub fn start(&self, step: Step) -> StepProgress {
478        self.progress.start(step, &self.example_name)
479    }
480}
481
482impl Drop for ExampleProgress {
483    fn drop(&mut self) {
484        self.progress.increment_completed();
485    }
486}
487
488pub struct StepProgress {
489    progress: Arc<Progress>,
490    step: Step,
491    example_name: String,
492}
493
494impl StepProgress {
495    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
496        let mut inner = self.progress.inner.lock().unwrap();
497        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
498            task.substatus = Some(substatus.into().into_owned());
499            Progress::clear_status_lines(&mut inner);
500            Progress::print_status_lines(&mut inner);
501        }
502    }
503
504    pub fn clear_substatus(&self) {
505        let mut inner = self.progress.inner.lock().unwrap();
506        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
507            task.substatus = None;
508            Progress::clear_status_lines(&mut inner);
509            Progress::print_status_lines(&mut inner);
510        }
511    }
512
513    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
514        let mut inner = self.progress.inner.lock().unwrap();
515        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
516            task.info = Some((info.into(), style));
517        }
518    }
519}
520
521impl Drop for StepProgress {
522    fn drop(&mut self) {
523        self.progress.finish(self.step, &self.example_name);
524    }
525}
526
/// `log` backend that routes records through the global [`Progress`] (once it exists) so
/// they are printed above the live status block instead of corrupting it.
struct ProgressLogger;
528
529impl Log for ProgressLogger {
530    fn enabled(&self, metadata: &Metadata) -> bool {
531        metadata.level() <= Level::Info
532    }
533
534    fn log(&self, record: &Record) {
535        if !self.enabled(record.metadata()) {
536            return;
537        }
538
539        let level_color = match record.level() {
540            Level::Error => "\x1b[31m",
541            Level::Warn => "\x1b[33m",
542            Level::Info => "\x1b[32m",
543            Level::Debug => "\x1b[34m",
544            Level::Trace => "\x1b[35m",
545        };
546        let reset = "\x1b[0m";
547        let bold = "\x1b[1m";
548
549        let level_label = match record.level() {
550            Level::Error => "Error",
551            Level::Warn => "Warn",
552            Level::Info => "Info",
553            Level::Debug => "Debug",
554            Level::Trace => "Trace",
555        };
556
557        let message = format!(
558            "{bold}{level_color}{level_label:>12}{reset} {}",
559            record.args()
560        );
561
562        if let Some(progress) = GLOBAL.get() {
563            progress.log(&message);
564        } else {
565            eprintln!("{}", message);
566        }
567    }
568
569    fn flush(&self) {
570        let _ = std::io::stderr().flush();
571    }
572}
573
574#[cfg(unix)]
575fn get_terminal_width() -> usize {
576    unsafe {
577        let mut winsize: libc::winsize = std::mem::zeroed();
578        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
579            && winsize.ws_col > 0
580        {
581            winsize.ws_col as usize
582        } else {
583            80
584        }
585    }
586}
587
588#[cfg(not(unix))]
589fn get_terminal_width() -> usize {
590    80
591}
592
/// Counts the characters of `s` that are visible once ANSI escape sequences are removed.
///
/// An escape sequence runs from `ESC` (`\x1b`) up to and including the next `m`, which
/// covers the SGR color/style codes used in this file.
fn strip_ansi_len(s: &str) -> usize {
    s.chars()
        .scan(false, |in_escape, c| {
            let visible = match c {
                '\x1b' => {
                    *in_escape = true;
                    false
                }
                'm' if *in_escape => {
                    *in_escape = false;
                    false
                }
                _ => !*in_escape,
            };
            Some(visible)
        })
        .filter(|&visible| visible)
        .count()
}
609
/// Shortens `s` to at most `max_len` characters, replacing the cut-off tail with `…`.
///
/// Lengths are counted in `char`s and the cut always lands on a character boundary, so
/// non-ASCII input cannot cause a slicing panic. Returns `s` unchanged (borrowed) when it
/// already fits. Returns an empty string when `max_len` is 0 and `s` is non-empty.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> Cow<'_, str> {
    if s.chars().count() <= max_len {
        return Cow::Borrowed(s);
    }
    if max_len == 0 {
        return Cow::Borrowed("");
    }
    // Keep `max_len - 1` characters so the result, ellipsis included, is `max_len` wide.
    let cut = s
        .char_indices()
        .nth(max_len - 1)
        .map_or(s.len(), |(byte_index, _)| byte_index);
    Cow::Owned(format!("{}…", &s[..cut]))
}
617
/// Returns the longest prefix of `s` that contains at most `max_visible_len` visible
/// characters, treating ANSI escape sequences (`ESC` ... `m`) as zero-width.
///
/// Escape sequences that come right after the last kept visible character are retained.
/// The result is always a borrowed slice ending on a character boundary.
fn truncate_to_visible_width(s: &str, max_visible_len: usize) -> &str {
    let mut shown = 0;
    let mut escaping = false;
    for (idx, ch) in s.char_indices() {
        match (ch, escaping) {
            ('\x1b', _) => escaping = true,
            ('m', true) => escaping = false,
            (_, true) => {}
            (_, false) if shown >= max_visible_len => return &s[..idx],
            (_, false) => shown += 1,
        }
    }
    s
}
639
/// Formats `duration` compactly: whole milliseconds below one second (`"250ms"`), seconds
/// with one decimal below one minute (`"1.5s"`), otherwise minutes with one decimal
/// (`"1.5m"`).
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    match duration.as_millis() as f32 {
        ms if ms < MILLIS_PER_SECOND => format!("{ms}ms"),
        ms if ms < MILLIS_PER_MINUTE => format!("{:.1}s", ms / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms / MILLIS_PER_MINUTE),
    }
}