progress.rs

  1use std::{
  2    borrow::Cow,
  3    collections::HashMap,
  4    io::{IsTerminal, Write},
  5    sync::{Arc, Mutex, OnceLock},
  6    time::{Duration, Instant},
  7};
  8
  9use crate::paths::RUN_DIR;
 10
 11use log::{Level, Log, Metadata, Record};
 12
 13pub struct Progress {
 14    inner: Mutex<ProgressInner>,
 15}
 16
 17struct ProgressInner {
 18    completed: Vec<CompletedTask>,
 19    in_progress: HashMap<String, InProgressTask>,
 20    is_tty: bool,
 21    terminal_width: usize,
 22    max_example_name_len: usize,
 23    status_lines_displayed: usize,
 24    total_examples: usize,
 25    completed_examples: usize,
 26    failed_examples: usize,
 27    last_line_is_logging: bool,
 28    ticker: Option<std::thread::JoinHandle<()>>,
 29}
 30
 31#[derive(Clone)]
 32struct InProgressTask {
 33    step: Step,
 34    started_at: Instant,
 35    substatus: Option<String>,
 36    info: Option<(String, InfoStyle)>,
 37}
 38
 39struct CompletedTask {
 40    step: Step,
 41    example_name: String,
 42    duration: Duration,
 43    info: Option<(String, InfoStyle)>,
 44}
 45
 46#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 47pub enum Step {
 48    LoadProject,
 49    Context,
 50    FormatPrompt,
 51    Predict,
 52    Score,
 53    Synthesize,
 54    PullExamples,
 55}
 56
 57#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 58pub enum InfoStyle {
 59    Normal,
 60    Warning,
 61}
 62
 63impl Step {
 64    pub fn label(&self) -> &'static str {
 65        match self {
 66            Step::LoadProject => "Load",
 67            Step::Context => "Context",
 68            Step::FormatPrompt => "Format",
 69            Step::Predict => "Predict",
 70            Step::Score => "Score",
 71            Step::Synthesize => "Synthesize",
 72            Step::PullExamples => "Pull",
 73        }
 74    }
 75
 76    fn color_code(&self) -> &'static str {
 77        match self {
 78            Step::LoadProject => "\x1b[33m",
 79            Step::Context => "\x1b[35m",
 80            Step::FormatPrompt => "\x1b[34m",
 81            Step::Predict => "\x1b[32m",
 82            Step::Score => "\x1b[31m",
 83            Step::Synthesize => "\x1b[36m",
 84            Step::PullExamples => "\x1b[36m",
 85        }
 86    }
 87}
 88
 89static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
 90static LOGGER: ProgressLogger = ProgressLogger;
 91
 92const MARGIN: usize = 4;
 93const MAX_STATUS_LINES: usize = 10;
 94const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
 95
 96impl Progress {
 97    /// Returns the global Progress instance, initializing it if necessary.
 98    pub fn global() -> Arc<Progress> {
 99        GLOBAL
100            .get_or_init(|| {
101                let progress = Arc::new(Self {
102                    inner: Mutex::new(ProgressInner {
103                        completed: Vec::new(),
104                        in_progress: HashMap::new(),
105                        is_tty: std::env::var("COLOR").is_ok()
106                            || (std::env::var("NO_COLOR").is_err()
107                                && std::io::stderr().is_terminal()),
108                        terminal_width: get_terminal_width(),
109                        max_example_name_len: 0,
110                        status_lines_displayed: 0,
111                        total_examples: 0,
112                        completed_examples: 0,
113                        failed_examples: 0,
114                        last_line_is_logging: false,
115                        ticker: None,
116                    }),
117                });
118                let _ = log::set_logger(&LOGGER);
119                log::set_max_level(log::LevelFilter::Info);
120                progress
121            })
122            .clone()
123    }
124
125    pub fn start_group(self: &Arc<Self>, example_name: &str) -> ExampleProgress {
126        ExampleProgress {
127            progress: self.clone(),
128            example_name: example_name.to_string(),
129        }
130    }
131
132    fn increment_completed(&self) {
133        let mut inner = self.inner.lock().unwrap();
134        inner.completed_examples += 1;
135    }
136
137    pub fn set_total_examples(&self, total: usize) {
138        let mut inner = self.inner.lock().unwrap();
139        inner.total_examples = total;
140    }
141
142    pub fn set_max_example_name_len(&self, example_names: impl Iterator<Item = impl AsRef<str>>) {
143        let mut inner = self.inner.lock().unwrap();
144        let max_name_width = inner
145            .terminal_width
146            .saturating_sub(MARGIN * 2)
147            .saturating_div(3)
148            .max(1);
149        inner.max_example_name_len = example_names
150            .map(|name| name.as_ref().len().min(max_name_width))
151            .max()
152            .unwrap_or(0);
153    }
154
155    pub fn increment_failed(&self) {
156        let mut inner = self.inner.lock().unwrap();
157        inner.failed_examples += 1;
158    }
159
160    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
161    /// This should be used for any output that needs to appear above the status lines.
162    fn log(&self, message: &str) {
163        let mut inner = self.inner.lock().unwrap();
164        Self::clear_status_lines(&mut inner);
165
166        if !inner.last_line_is_logging {
167            let reset = "\x1b[0m";
168            let dim = "\x1b[2m";
169            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
170            eprintln!("{dim}{divider}{reset}");
171            inner.last_line_is_logging = true;
172        }
173
174        let max_width = inner.terminal_width.saturating_sub(MARGIN);
175        for line in message.lines() {
176            let truncated = truncate_to_visible_width(line, max_width);
177            if truncated.len() < line.len() {
178                eprintln!("{}", truncated);
179            } else {
180                eprintln!("{}", truncated);
181            }
182        }
183
184        Self::print_status_lines(&mut inner);
185    }
186
187    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
188        let mut inner = self.inner.lock().unwrap();
189
190        Self::clear_status_lines(&mut inner);
191
192        // Update max_example_name_len if not already set via set_max_example_name_len
193        if inner.max_example_name_len == 0 {
194            let max_name_width = inner
195                .terminal_width
196                .saturating_sub(MARGIN * 2)
197                .saturating_div(3)
198                .max(1);
199            inner.max_example_name_len = example_name.len().min(max_name_width);
200        }
201        inner.in_progress.insert(
202            example_name.to_string(),
203            InProgressTask {
204                step,
205                started_at: Instant::now(),
206                substatus: None,
207                info: None,
208            },
209        );
210
211        if inner.is_tty && inner.ticker.is_none() {
212            let progress = self.clone();
213            inner.ticker = Some(std::thread::spawn(move || {
214                loop {
215                    std::thread::sleep(STATUS_TICK_INTERVAL);
216
217                    let mut inner = progress.inner.lock().unwrap();
218                    if inner.in_progress.is_empty() {
219                        break;
220                    }
221
222                    Progress::clear_status_lines(&mut inner);
223                    Progress::print_status_lines(&mut inner);
224                }
225            }));
226        }
227
228        Self::print_status_lines(&mut inner);
229
230        StepProgress {
231            progress: self.clone(),
232            step,
233            example_name: example_name.to_string(),
234        }
235    }
236
237    fn finish(&self, step: Step, example_name: &str) {
238        let mut inner = self.inner.lock().unwrap();
239
240        let Some(task) = inner.in_progress.remove(example_name) else {
241            return;
242        };
243
244        if task.step == step {
245            let duration = task.started_at.elapsed();
246
247            // Skip logging for tasks that complete quickly (under 500ms)
248            let should_print = duration >= Duration::from_millis(500);
249
250            inner.completed.push(CompletedTask {
251                step: task.step,
252                example_name: example_name.to_string(),
253                duration,
254                info: task.info,
255            });
256
257            Self::clear_status_lines(&mut inner);
258            if should_print {
259                Self::print_logging_closing_divider(&mut inner);
260                if let Some(last_completed) = inner.completed.last() {
261                    Self::print_completed(&inner, last_completed);
262                }
263            }
264            Self::print_status_lines(&mut inner);
265        } else {
266            inner.in_progress.insert(example_name.to_string(), task);
267        }
268    }
269
270    fn print_logging_closing_divider(inner: &mut ProgressInner) {
271        if inner.last_line_is_logging {
272            let reset = "\x1b[0m";
273            let dim = "\x1b[2m";
274            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
275            eprintln!("{dim}{divider}{reset}");
276            inner.last_line_is_logging = false;
277        }
278    }
279
280    fn clear_status_lines(inner: &mut ProgressInner) {
281        if inner.is_tty && inner.status_lines_displayed > 0 {
282            // Move up and clear each line we previously displayed
283            for _ in 0..inner.status_lines_displayed {
284                eprint!("\x1b[A\x1b[K");
285            }
286            inner.status_lines_displayed = 0;
287        }
288    }
289
290    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
291        let duration = format_duration(task.duration);
292        let name_width = inner.max_example_name_len;
293        let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
294
295        if inner.is_tty {
296            let reset = "\x1b[0m";
297            let bold = "\x1b[1m";
298            let dim = "\x1b[2m";
299
300            let yellow = "\x1b[33m";
301            let info_part = task
302                .info
303                .as_ref()
304                .map(|(s, style)| {
305                    if *style == InfoStyle::Warning {
306                        format!("{yellow}{s}{reset}")
307                    } else {
308                        s.to_string()
309                    }
310                })
311                .unwrap_or_default();
312
313            let prefix = format!(
314                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
315                color = task.step.color_code(),
316                label = task.step.label(),
317                name = truncated_name,
318            );
319
320            let duration_with_margin = format!("{duration} ");
321            let padding_needed = inner
322                .terminal_width
323                .saturating_sub(MARGIN)
324                .saturating_sub(duration_with_margin.len())
325                .saturating_sub(strip_ansi_len(&prefix));
326            let padding = " ".repeat(padding_needed);
327
328            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
329        } else {
330            let info_part = task
331                .info
332                .as_ref()
333                .map(|(s, _)| format!(" | {}", s))
334                .unwrap_or_default();
335
336            eprintln!(
337                "{label:>12} {name:<name_width$}{info_part} {duration}",
338                label = task.step.label(),
339                name = truncate_with_ellipsis(&task.example_name, name_width),
340            );
341        }
342    }
343
344    fn print_status_lines(inner: &mut ProgressInner) {
345        if !inner.is_tty || inner.in_progress.is_empty() {
346            inner.status_lines_displayed = 0;
347            return;
348        }
349
350        let reset = "\x1b[0m";
351        let bold = "\x1b[1m";
352        let dim = "\x1b[2m";
353
354        // Build the done/in-progress/total label
355        let done_count = inner.completed_examples;
356        let in_progress_count = inner.in_progress.len();
357        let failed_count = inner.failed_examples;
358
359        let failed_label = if failed_count > 0 {
360            format!(" {} failed ", failed_count)
361        } else {
362            String::new()
363        };
364
365        let range_label = format!(
366            " {}/{}/{} ",
367            done_count, in_progress_count, inner.total_examples
368        );
369
370        // Print a divider line with failed count on left, range label on right
371        let failed_visible_len = strip_ansi_len(&failed_label);
372        let range_visible_len = range_label.len();
373        let middle_divider_len = inner
374            .terminal_width
375            .saturating_sub(MARGIN * 2)
376            .saturating_sub(failed_visible_len)
377            .saturating_sub(range_visible_len);
378        let left_divider = "".repeat(MARGIN);
379        let middle_divider = "".repeat(middle_divider_len);
380        let right_divider = "".repeat(MARGIN);
381        eprintln!(
382            "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
383        );
384
385        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
386        tasks.sort_by_key(|(name, _)| *name);
387
388        let total_tasks = tasks.len();
389        let mut lines_printed = 0;
390
391        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
392            let elapsed = format_duration(task.started_at.elapsed());
393            let substatus_part = task
394                .substatus
395                .as_ref()
396                .map(|s| truncate_with_ellipsis(s, 30))
397                .unwrap_or_default();
398
399            let step_label = task.step.label();
400            let step_color = task.step.color_code();
401            let name_width = inner.max_example_name_len;
402            let truncated_name = truncate_with_ellipsis(name, name_width);
403
404            let prefix = format!(
405                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
406                name = truncated_name,
407            );
408
409            let duration_with_margin = format!("{elapsed} ");
410            let padding_needed = inner
411                .terminal_width
412                .saturating_sub(MARGIN)
413                .saturating_sub(duration_with_margin.len())
414                .saturating_sub(strip_ansi_len(&prefix));
415            let padding = " ".repeat(padding_needed);
416
417            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
418            lines_printed += 1;
419        }
420
421        // Show "+N more" on its own line if there are more tasks
422        if total_tasks > MAX_STATUS_LINES {
423            let remaining = total_tasks - MAX_STATUS_LINES;
424            eprintln!("{:>12} +{remaining} more", "");
425            lines_printed += 1;
426        }
427
428        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
429        let _ = std::io::stderr().flush();
430    }
431
432    pub fn finalize(&self) {
433        let ticker = {
434            let mut inner = self.inner.lock().unwrap();
435            inner.ticker.take()
436        };
437
438        if let Some(ticker) = ticker {
439            let _ = ticker.join();
440        }
441
442        let mut inner = self.inner.lock().unwrap();
443        Self::clear_status_lines(&mut inner);
444
445        // Print summary if there were failures
446        if inner.failed_examples > 0 {
447            let total_examples = inner.total_examples;
448            let percentage = if total_examples > 0 {
449                inner.failed_examples as f64 / total_examples as f64 * 100.0
450            } else {
451                0.0
452            };
453            let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
454            eprintln!(
455                "\n{} of {} examples failed ({:.1}%)\nFailed examples: {}",
456                inner.failed_examples,
457                total_examples,
458                percentage,
459                failed_jsonl_path.display()
460            );
461        }
462    }
463}
464
465pub struct ExampleProgress {
466    progress: Arc<Progress>,
467    example_name: String,
468}
469
470impl ExampleProgress {
471    pub fn start(&self, step: Step) -> StepProgress {
472        self.progress.start(step, &self.example_name)
473    }
474}
475
476impl Drop for ExampleProgress {
477    fn drop(&mut self) {
478        self.progress.increment_completed();
479    }
480}
481
482pub struct StepProgress {
483    progress: Arc<Progress>,
484    step: Step,
485    example_name: String,
486}
487
488impl StepProgress {
489    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
490        let mut inner = self.progress.inner.lock().unwrap();
491        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
492            task.substatus = Some(substatus.into().into_owned());
493            Progress::clear_status_lines(&mut inner);
494            Progress::print_status_lines(&mut inner);
495        }
496    }
497
498    pub fn clear_substatus(&self) {
499        let mut inner = self.progress.inner.lock().unwrap();
500        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
501            task.substatus = None;
502            Progress::clear_status_lines(&mut inner);
503            Progress::print_status_lines(&mut inner);
504        }
505    }
506
507    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
508        let mut inner = self.progress.inner.lock().unwrap();
509        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
510            task.info = Some((info.into(), style));
511        }
512    }
513}
514
515impl Drop for StepProgress {
516    fn drop(&mut self) {
517        self.progress.finish(self.step, &self.example_name);
518    }
519}
520
/// `log` backend that routes records through the shared reporter so they do not corrupt the
/// status area.
struct ProgressLogger;
522
523impl Log for ProgressLogger {
524    fn enabled(&self, metadata: &Metadata) -> bool {
525        metadata.level() <= Level::Info
526    }
527
528    fn log(&self, record: &Record) {
529        if !self.enabled(record.metadata()) {
530            return;
531        }
532
533        let level_color = match record.level() {
534            Level::Error => "\x1b[31m",
535            Level::Warn => "\x1b[33m",
536            Level::Info => "\x1b[32m",
537            Level::Debug => "\x1b[34m",
538            Level::Trace => "\x1b[35m",
539        };
540        let reset = "\x1b[0m";
541        let bold = "\x1b[1m";
542
543        let level_label = match record.level() {
544            Level::Error => "Error",
545            Level::Warn => "Warn",
546            Level::Info => "Info",
547            Level::Debug => "Debug",
548            Level::Trace => "Trace",
549        };
550
551        let message = format!(
552            "{bold}{level_color}{level_label:>12}{reset} {}",
553            record.args()
554        );
555
556        if let Some(progress) = GLOBAL.get() {
557            progress.log(&message);
558        } else {
559            eprintln!("{}", message);
560        }
561    }
562
563    fn flush(&self) {
564        let _ = std::io::stderr().flush();
565    }
566}
567
568#[cfg(unix)]
569fn get_terminal_width() -> usize {
570    unsafe {
571        let mut winsize: libc::winsize = std::mem::zeroed();
572        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
573            && winsize.ws_col > 0
574        {
575            winsize.ws_col as usize
576        } else {
577            80
578        }
579    }
580}
581
/// Returns a fixed width of 80 columns; the terminal is not queried on non-Unix targets.
#[cfg(not(unix))]
fn get_terminal_width() -> usize {
    const DEFAULT_COLUMNS: usize = 80;
    DEFAULT_COLUMNS
}
586
/// Counts the characters of `s` that are not part of an escape sequence.
///
/// A sequence starts at ESC (`\x1b`) and runs up to and including the next `m`; an
/// unterminated sequence swallows the rest of the string. The result is a count of `char`s,
/// not of terminal columns.
fn strip_ansi_len(s: &str) -> usize {
    let mut inside_escape = false;
    s.chars()
        .filter(|&ch| {
            if ch == '\x1b' {
                inside_escape = true;
                false
            } else if inside_escape {
                // Still inside the sequence until its terminating `m` has been consumed.
                inside_escape = ch != 'm';
                false
            } else {
                true
            }
        })
        .count()
}
603
/// Shortens `s` to at most `max_len` characters, ending it with `…` when text was dropped.
///
/// Returns `s` unchanged (borrowed) when it already fits. Otherwise the first `max_len - 1`
/// characters are kept and an ellipsis is appended, so the result is exactly `max_len`
/// characters long. With `max_len == 0` the result is empty. Lengths are measured in `char`s
/// and cuts always fall on character boundaries, so multi-byte text never causes a panic.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> Cow<'_, str> {
    // Byte length is an upper bound on the character count, so this is a cheap fast path.
    if s.len() <= max_len {
        return Cow::Borrowed(s);
    }
    if max_len == 0 {
        return Cow::Borrowed("");
    }

    let mut char_starts = s.char_indices().map(|(byte_index, _)| byte_index);
    // Byte offset of the `max_len`-th character, which is where the ellipsis will go.
    let Some(keep_end) = char_starts.nth(max_len - 1) else {
        // Fewer than `max_len` characters: it fits.
        return Cow::Borrowed(s);
    };
    if char_starts.next().is_none() {
        // Exactly `max_len` characters: it fits.
        return Cow::Borrowed(s);
    }

    Cow::Owned(format!("{}…", &s[..keep_end]))
}
611
/// Returns the longest prefix of `s` holding at most `max_visible_len` visible characters.
///
/// Characters inside an escape sequence (ESC up to and including the next `m`) are not
/// counted. The cut is made right before the first visible character that would exceed the
/// limit, so escape sequences directly preceding that character are kept. When nothing needs
/// to be cut, `s` is returned whole.
fn truncate_to_visible_width(s: &str, max_visible_len: usize) -> &str {
    let mut shown = 0;
    let mut inside_escape = false;

    for (start, ch) in s.char_indices() {
        match (inside_escape, ch) {
            (_, '\x1b') => inside_escape = true,
            (true, 'm') => inside_escape = false,
            (true, _) => {}
            // `start` is the byte offset where this over-limit character begins.
            (false, _) if shown >= max_visible_len => return &s[..start],
            (false, _) => shown += 1,
        }
    }
    s
}
633
/// Formats a duration compactly for display.
///
/// Below one second it is shown as whole milliseconds (`"250ms"`), below one minute as
/// seconds with one decimal (`"1.5s"`), and otherwise as minutes with one decimal (`"1.5m"`).
fn format_duration(duration: Duration) -> String {
    const SECOND_IN_MILLIS: f32 = 1000.;
    const MINUTE_IN_MILLIS: f32 = 60. * 1000.;

    match duration.as_millis() as f32 {
        ms if ms < SECOND_IN_MILLIS => format!("{}ms", ms),
        ms if ms < MINUTE_IN_MILLIS => format!("{:.1}s", ms / 1_000.0),
        ms => format!("{:.1}m", ms / MINUTE_IN_MILLIS),
    }
}
645}