progress.rs

  1use std::{
  2    borrow::Cow,
  3    collections::HashMap,
  4    io::{IsTerminal, Write},
  5    sync::{Arc, Mutex, OnceLock},
  6    time::{Duration, Instant},
  7};
  8
  9use crate::paths::RUN_DIR;
 10
 11use log::{Level, Log, Metadata, Record};
 12
 13pub struct Progress {
 14    inner: Mutex<ProgressInner>,
 15}
 16
 17struct ProgressInner {
 18    completed: Vec<CompletedTask>,
 19    in_progress: HashMap<String, InProgressTask>,
 20    is_tty: bool,
 21    terminal_width: usize,
 22    max_example_name_len: usize,
 23    status_lines_displayed: usize,
 24    total_examples: usize,
 25    failed_examples: usize,
 26    last_line_is_logging: bool,
 27    ticker: Option<std::thread::JoinHandle<()>>,
 28}
 29
 30#[derive(Clone)]
 31struct InProgressTask {
 32    step: Step,
 33    started_at: Instant,
 34    substatus: Option<String>,
 35    info: Option<(String, InfoStyle)>,
 36}
 37
 38struct CompletedTask {
 39    step: Step,
 40    example_name: String,
 41    duration: Duration,
 42    info: Option<(String, InfoStyle)>,
 43}
 44
 45#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 46pub enum Step {
 47    LoadProject,
 48    Context,
 49    FormatPrompt,
 50    Predict,
 51    Score,
 52    Synthesize,
 53    PullExamples,
 54}
 55
 56#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 57pub enum InfoStyle {
 58    Normal,
 59    Warning,
 60}
 61
 62impl Step {
 63    pub fn label(&self) -> &'static str {
 64        match self {
 65            Step::LoadProject => "Load",
 66            Step::Context => "Context",
 67            Step::FormatPrompt => "Format",
 68            Step::Predict => "Predict",
 69            Step::Score => "Score",
 70            Step::Synthesize => "Synthesize",
 71            Step::PullExamples => "Pull",
 72        }
 73    }
 74
 75    fn color_code(&self) -> &'static str {
 76        match self {
 77            Step::LoadProject => "\x1b[33m",
 78            Step::Context => "\x1b[35m",
 79            Step::FormatPrompt => "\x1b[34m",
 80            Step::Predict => "\x1b[32m",
 81            Step::Score => "\x1b[31m",
 82            Step::Synthesize => "\x1b[36m",
 83            Step::PullExamples => "\x1b[36m",
 84        }
 85    }
 86}
 87
 88static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
 89static LOGGER: ProgressLogger = ProgressLogger;
 90
 91const MARGIN: usize = 4;
 92const MAX_STATUS_LINES: usize = 10;
 93const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
 94
 95impl Progress {
 96    /// Returns the global Progress instance, initializing it if necessary.
 97    pub fn global() -> Arc<Progress> {
 98        GLOBAL
 99            .get_or_init(|| {
100                let progress = Arc::new(Self {
101                    inner: Mutex::new(ProgressInner {
102                        completed: Vec::new(),
103                        in_progress: HashMap::new(),
104                        is_tty: std::env::var("NO_COLOR").is_err()
105                            && std::io::stderr().is_terminal(),
106                        terminal_width: get_terminal_width(),
107                        max_example_name_len: 0,
108                        status_lines_displayed: 0,
109                        total_examples: 0,
110                        failed_examples: 0,
111                        last_line_is_logging: false,
112                        ticker: None,
113                    }),
114                });
115                let _ = log::set_logger(&LOGGER);
116                log::set_max_level(log::LevelFilter::Info);
117                progress
118            })
119            .clone()
120    }
121
122    pub fn set_total_examples(&self, total: usize) {
123        let mut inner = self.inner.lock().unwrap();
124        inner.total_examples = total;
125    }
126
127    pub fn increment_failed(&self) {
128        let mut inner = self.inner.lock().unwrap();
129        inner.failed_examples += 1;
130    }
131
132    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
133    /// This should be used for any output that needs to appear above the status lines.
134    fn log(&self, message: &str) {
135        let mut inner = self.inner.lock().unwrap();
136        Self::clear_status_lines(&mut inner);
137
138        if !inner.last_line_is_logging {
139            let reset = "\x1b[0m";
140            let dim = "\x1b[2m";
141            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
142            eprintln!("{dim}{divider}{reset}");
143            inner.last_line_is_logging = true;
144        }
145
146        let max_width = inner.terminal_width.saturating_sub(MARGIN);
147        for line in message.lines() {
148            let truncated = truncate_to_visible_width(line, max_width);
149            if truncated.len() < line.len() {
150                eprintln!("{}", truncated);
151            } else {
152                eprintln!("{}", truncated);
153            }
154        }
155
156        Self::print_status_lines(&mut inner);
157    }
158
159    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
160        let mut inner = self.inner.lock().unwrap();
161
162        Self::clear_status_lines(&mut inner);
163
164        let max_name_width = inner
165            .terminal_width
166            .saturating_sub(MARGIN * 2)
167            .saturating_div(3)
168            .max(1);
169        inner.max_example_name_len = inner
170            .max_example_name_len
171            .max(example_name.len().min(max_name_width));
172        inner.in_progress.insert(
173            example_name.to_string(),
174            InProgressTask {
175                step,
176                started_at: Instant::now(),
177                substatus: None,
178                info: None,
179            },
180        );
181
182        if inner.is_tty && inner.ticker.is_none() {
183            let progress = self.clone();
184            inner.ticker = Some(std::thread::spawn(move || {
185                loop {
186                    std::thread::sleep(STATUS_TICK_INTERVAL);
187
188                    let mut inner = progress.inner.lock().unwrap();
189                    if inner.in_progress.is_empty() {
190                        break;
191                    }
192
193                    Progress::clear_status_lines(&mut inner);
194                    Progress::print_status_lines(&mut inner);
195                }
196            }));
197        }
198
199        Self::print_status_lines(&mut inner);
200
201        StepProgress {
202            progress: self.clone(),
203            step,
204            example_name: example_name.to_string(),
205        }
206    }
207
208    fn finish(&self, step: Step, example_name: &str) {
209        let mut inner = self.inner.lock().unwrap();
210
211        let Some(task) = inner.in_progress.remove(example_name) else {
212            return;
213        };
214
215        if task.step == step {
216            inner.completed.push(CompletedTask {
217                step: task.step,
218                example_name: example_name.to_string(),
219                duration: task.started_at.elapsed(),
220                info: task.info,
221            });
222
223            Self::clear_status_lines(&mut inner);
224            Self::print_logging_closing_divider(&mut inner);
225            if let Some(last_completed) = inner.completed.last() {
226                Self::print_completed(&inner, last_completed);
227            }
228            Self::print_status_lines(&mut inner);
229        } else {
230            inner.in_progress.insert(example_name.to_string(), task);
231        }
232    }
233
234    fn print_logging_closing_divider(inner: &mut ProgressInner) {
235        if inner.last_line_is_logging {
236            let reset = "\x1b[0m";
237            let dim = "\x1b[2m";
238            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
239            eprintln!("{dim}{divider}{reset}");
240            inner.last_line_is_logging = false;
241        }
242    }
243
244    fn clear_status_lines(inner: &mut ProgressInner) {
245        if inner.is_tty && inner.status_lines_displayed > 0 {
246            // Move up and clear each line we previously displayed
247            for _ in 0..inner.status_lines_displayed {
248                eprint!("\x1b[A\x1b[K");
249            }
250            let _ = std::io::stderr().flush();
251            inner.status_lines_displayed = 0;
252        }
253    }
254
255    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
256        let duration = format_duration(task.duration);
257        let name_width = inner.max_example_name_len;
258        let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
259
260        if inner.is_tty {
261            let reset = "\x1b[0m";
262            let bold = "\x1b[1m";
263            let dim = "\x1b[2m";
264
265            let yellow = "\x1b[33m";
266            let info_part = task
267                .info
268                .as_ref()
269                .map(|(s, style)| {
270                    if *style == InfoStyle::Warning {
271                        format!("{yellow}{s}{reset}")
272                    } else {
273                        s.to_string()
274                    }
275                })
276                .unwrap_or_default();
277
278            let prefix = format!(
279                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
280                color = task.step.color_code(),
281                label = task.step.label(),
282                name = truncated_name,
283            );
284
285            let duration_with_margin = format!("{duration} ");
286            let padding_needed = inner
287                .terminal_width
288                .saturating_sub(MARGIN)
289                .saturating_sub(duration_with_margin.len())
290                .saturating_sub(strip_ansi_len(&prefix));
291            let padding = " ".repeat(padding_needed);
292
293            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
294        } else {
295            let info_part = task
296                .info
297                .as_ref()
298                .map(|(s, _)| format!(" | {}", s))
299                .unwrap_or_default();
300
301            eprintln!(
302                "{label:>12} {name:<name_width$}{info_part} {duration}",
303                label = task.step.label(),
304                name = truncate_with_ellipsis(&task.example_name, name_width),
305            );
306        }
307    }
308
309    fn print_status_lines(inner: &mut ProgressInner) {
310        if !inner.is_tty || inner.in_progress.is_empty() {
311            inner.status_lines_displayed = 0;
312            return;
313        }
314
315        let reset = "\x1b[0m";
316        let bold = "\x1b[1m";
317        let dim = "\x1b[2m";
318
319        // Build the done/in-progress/total label
320        let done_count = inner.completed.len();
321        let in_progress_count = inner.in_progress.len();
322        let failed_count = inner.failed_examples;
323
324        let failed_label = if failed_count > 0 {
325            format!(" {} failed ", failed_count)
326        } else {
327            String::new()
328        };
329
330        let range_label = format!(
331            " {}/{}/{} ",
332            done_count, in_progress_count, inner.total_examples
333        );
334
335        // Print a divider line with failed count on left, range label on right
336        let failed_visible_len = strip_ansi_len(&failed_label);
337        let range_visible_len = range_label.len();
338        let middle_divider_len = inner
339            .terminal_width
340            .saturating_sub(MARGIN * 2)
341            .saturating_sub(failed_visible_len)
342            .saturating_sub(range_visible_len);
343        let left_divider = "".repeat(MARGIN);
344        let middle_divider = "".repeat(middle_divider_len);
345        let right_divider = "".repeat(MARGIN);
346        eprintln!(
347            "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
348        );
349
350        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
351        tasks.sort_by_key(|(name, _)| *name);
352
353        let total_tasks = tasks.len();
354        let mut lines_printed = 0;
355
356        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
357            let elapsed = format_duration(task.started_at.elapsed());
358            let substatus_part = task
359                .substatus
360                .as_ref()
361                .map(|s| truncate_with_ellipsis(s, 30))
362                .unwrap_or_default();
363
364            let step_label = task.step.label();
365            let step_color = task.step.color_code();
366            let name_width = inner.max_example_name_len;
367            let truncated_name = truncate_with_ellipsis(name, name_width);
368
369            let prefix = format!(
370                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
371                name = truncated_name,
372            );
373
374            let duration_with_margin = format!("{elapsed} ");
375            let padding_needed = inner
376                .terminal_width
377                .saturating_sub(MARGIN)
378                .saturating_sub(duration_with_margin.len())
379                .saturating_sub(strip_ansi_len(&prefix));
380            let padding = " ".repeat(padding_needed);
381
382            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
383            lines_printed += 1;
384        }
385
386        // Show "+N more" on its own line if there are more tasks
387        if total_tasks > MAX_STATUS_LINES {
388            let remaining = total_tasks - MAX_STATUS_LINES;
389            eprintln!("{:>12} +{remaining} more", "");
390            lines_printed += 1;
391        }
392
393        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
394        let _ = std::io::stderr().flush();
395    }
396
397    pub fn finalize(&self) {
398        let ticker = {
399            let mut inner = self.inner.lock().unwrap();
400            inner.ticker.take()
401        };
402
403        if let Some(ticker) = ticker {
404            let _ = ticker.join();
405        }
406
407        let mut inner = self.inner.lock().unwrap();
408        Self::clear_status_lines(&mut inner);
409
410        // Print summary if there were failures
411        if inner.failed_examples > 0 {
412            let total_examples = inner.total_examples;
413            let percentage = if total_examples > 0 {
414                inner.failed_examples as f64 / total_examples as f64 * 100.0
415            } else {
416                0.0
417            };
418            let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
419            eprintln!(
420                "\n{} of {} examples failed ({:.1}%)\nFailed examples: {}",
421                inner.failed_examples,
422                total_examples,
423                percentage,
424                failed_jsonl_path.display()
425            );
426        }
427    }
428}
429
430pub struct StepProgress {
431    progress: Arc<Progress>,
432    step: Step,
433    example_name: String,
434}
435
436impl StepProgress {
437    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
438        let mut inner = self.progress.inner.lock().unwrap();
439        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
440            task.substatus = Some(substatus.into().into_owned());
441            Progress::clear_status_lines(&mut inner);
442            Progress::print_status_lines(&mut inner);
443        }
444    }
445
446    pub fn clear_substatus(&self) {
447        let mut inner = self.progress.inner.lock().unwrap();
448        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
449            task.substatus = None;
450            Progress::clear_status_lines(&mut inner);
451            Progress::print_status_lines(&mut inner);
452        }
453    }
454
455    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
456        let mut inner = self.progress.inner.lock().unwrap();
457        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
458            task.info = Some((info.into(), style));
459        }
460    }
461}
462
463impl Drop for StepProgress {
464    fn drop(&mut self) {
465        self.progress.finish(self.step, &self.example_name);
466    }
467}
468
469struct ProgressLogger;
470
471impl Log for ProgressLogger {
472    fn enabled(&self, metadata: &Metadata) -> bool {
473        metadata.level() <= Level::Info
474    }
475
476    fn log(&self, record: &Record) {
477        if !self.enabled(record.metadata()) {
478            return;
479        }
480
481        let level_color = match record.level() {
482            Level::Error => "\x1b[31m",
483            Level::Warn => "\x1b[33m",
484            Level::Info => "\x1b[32m",
485            Level::Debug => "\x1b[34m",
486            Level::Trace => "\x1b[35m",
487        };
488        let reset = "\x1b[0m";
489        let bold = "\x1b[1m";
490
491        let level_label = match record.level() {
492            Level::Error => "Error",
493            Level::Warn => "Warn",
494            Level::Info => "Info",
495            Level::Debug => "Debug",
496            Level::Trace => "Trace",
497        };
498
499        let message = format!(
500            "{bold}{level_color}{level_label:>12}{reset} {}",
501            record.args()
502        );
503
504        if let Some(progress) = GLOBAL.get() {
505            progress.log(&message);
506        } else {
507            eprintln!("{}", message);
508        }
509    }
510
511    fn flush(&self) {
512        let _ = std::io::stderr().flush();
513    }
514}
515
516#[cfg(unix)]
517fn get_terminal_width() -> usize {
518    unsafe {
519        let mut winsize: libc::winsize = std::mem::zeroed();
520        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
521            && winsize.ws_col > 0
522        {
523            winsize.ws_col as usize
524        } else {
525            80
526        }
527    }
528}
529
530#[cfg(not(unix))]
531fn get_terminal_width() -> usize {
532    80
533}
534
/// Counts the characters of `s` that would be visible on a terminal.
///
/// ANSI escape sequences are skipped: everything from `ESC` up to and including
/// the next `m`.
fn strip_ansi_len(s: &str) -> usize {
    s.chars()
        .scan(false, |inside_escape, ch| {
            let visible = if ch == '\x1b' {
                *inside_escape = true;
                false
            } else if *inside_escape {
                // Stay inside the sequence until its terminating `m`.
                *inside_escape = ch != 'm';
                false
            } else {
                true
            };
            Some(visible)
        })
        .filter(|&visible| visible)
        .count()
}
551
/// Shortens `s` to at most `max_len` characters, marking a cut with a trailing `…`.
///
/// Lengths are counted in `char`s rather than bytes, so the cut always lands on a
/// character boundary. A byte-based slice would panic on multi-byte text such as
/// `"héllo"`. Return value:
/// - `s` unchanged (borrowed) when it already fits;
/// - an empty string when `max_len` is 0;
/// - otherwise the first `max_len - 1` characters plus `…`, i.e. exactly `max_len`
///   characters.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> Cow<'_, str> {
    // `nth(max_len)` is `Some` exactly when `s` has more than `max_len` characters.
    if s.char_indices().nth(max_len).is_none() {
        return Cow::Borrowed(s);
    }
    if max_len == 0 {
        return Cow::Borrowed("");
    }
    let cut = s
        .char_indices()
        .nth(max_len - 1)
        .map_or(s.len(), |(index, _)| index);
    Cow::Owned(format!("{}…", &s[..cut]))
}
559
/// Returns the longest prefix of `s` that shows at most `max_visible_len` characters.
///
/// Escape sequences before the cut point are kept; the cut happens just before the
/// first visible character that would exceed the limit. Any escape sequences after
/// that point (e.g. a trailing reset) are dropped with the rest of the line.
fn truncate_to_visible_width(s: &str, max_visible_len: usize) -> &str {
    let mut shown = 0usize;
    let mut escaping = false;
    for (idx, ch) in s.char_indices() {
        match (escaping, ch) {
            (_, '\x1b') => escaping = true,
            (true, 'm') => escaping = false,
            (true, _) => {}
            (false, _) => {
                if shown >= max_visible_len {
                    return &s[..idx];
                }
                shown += 1;
            }
        }
    }
    s
}
581
/// Renders a duration compactly: whole milliseconds below one second, seconds with
/// one decimal below one minute, and minutes with one decimal beyond that.
fn format_duration(duration: Duration) -> String {
    const MS_PER_SECOND: f32 = 1_000.0;
    const MS_PER_MINUTE: f32 = 60.0 * MS_PER_SECOND;

    let ms = duration.as_millis() as f32;
    match ms {
        m if m < MS_PER_SECOND => format!("{m}ms"),
        m if m < MS_PER_MINUTE => format!("{:.1}s", m / MS_PER_SECOND),
        m => format!("{:.1}m", m / MS_PER_MINUTE),
    }
}