progress.rs

  1use std::{
  2    borrow::Cow,
  3    collections::HashMap,
  4    io::{IsTerminal, Write},
  5    sync::{Arc, Mutex, OnceLock},
  6    time::{Duration, Instant},
  7};
  8
  9use log::{Level, Log, Metadata, Record};
 10
 11pub struct Progress {
 12    inner: Mutex<ProgressInner>,
 13}
 14
 15struct ProgressInner {
 16    completed: Vec<CompletedTask>,
 17    in_progress: HashMap<String, InProgressTask>,
 18    is_tty: bool,
 19    terminal_width: usize,
 20    max_example_name_len: usize,
 21    status_lines_displayed: usize,
 22    total_examples: usize,
 23    failed_examples: usize,
 24    last_line_is_logging: bool,
 25    ticker: Option<std::thread::JoinHandle<()>>,
 26}
 27
 28#[derive(Clone)]
 29struct InProgressTask {
 30    step: Step,
 31    started_at: Instant,
 32    substatus: Option<String>,
 33    info: Option<(String, InfoStyle)>,
 34}
 35
 36struct CompletedTask {
 37    step: Step,
 38    example_name: String,
 39    duration: Duration,
 40    info: Option<(String, InfoStyle)>,
 41}
 42
 /// A unit of work whose progress is reported per example.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
    LoadProject,
    Context,
    FormatPrompt,
    Predict,
    Score,
    Synthesize,
    PullExamples,
}
 53
 /// Styling applied to the optional info text on a completion line.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed without extra styling.
    Normal,
    /// Highlighted (yellow) when output is styled.
    Warning,
}
 59
 60impl Step {
 61    pub fn label(&self) -> &'static str {
 62        match self {
 63            Step::LoadProject => "Load",
 64            Step::Context => "Context",
 65            Step::FormatPrompt => "Format",
 66            Step::Predict => "Predict",
 67            Step::Score => "Score",
 68            Step::Synthesize => "Synthesize",
 69            Step::PullExamples => "Pull",
 70        }
 71    }
 72
 73    fn color_code(&self) -> &'static str {
 74        match self {
 75            Step::LoadProject => "\x1b[33m",
 76            Step::Context => "\x1b[35m",
 77            Step::FormatPrompt => "\x1b[34m",
 78            Step::Predict => "\x1b[32m",
 79            Step::Score => "\x1b[31m",
 80            Step::Synthesize => "\x1b[36m",
 81            Step::PullExamples => "\x1b[36m",
 82        }
 83    }
 84}
 85
 86static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
 87static LOGGER: ProgressLogger = ProgressLogger;
 88
 89const MARGIN: usize = 4;
 90const MAX_STATUS_LINES: usize = 10;
 91const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
 92
 93impl Progress {
 94    /// Returns the global Progress instance, initializing it if necessary.
 95    pub fn global() -> Arc<Progress> {
 96        GLOBAL
 97            .get_or_init(|| {
 98                let progress = Arc::new(Self {
 99                    inner: Mutex::new(ProgressInner {
100                        completed: Vec::new(),
101                        in_progress: HashMap::new(),
102                        is_tty: std::env::var("NO_COLOR").is_err()
103                            && std::io::stderr().is_terminal(),
104                        terminal_width: get_terminal_width(),
105                        max_example_name_len: 0,
106                        status_lines_displayed: 0,
107                        total_examples: 0,
108                        failed_examples: 0,
109                        last_line_is_logging: false,
110                        ticker: None,
111                    }),
112                });
113                let _ = log::set_logger(&LOGGER);
114                log::set_max_level(log::LevelFilter::Info);
115                progress
116            })
117            .clone()
118    }
119
120    pub fn set_total_examples(&self, total: usize) {
121        let mut inner = self.inner.lock().unwrap();
122        inner.total_examples = total;
123    }
124
125    pub fn increment_failed(&self) {
126        let mut inner = self.inner.lock().unwrap();
127        inner.failed_examples += 1;
128    }
129
130    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
131    /// This should be used for any output that needs to appear above the status lines.
132    fn log(&self, message: &str) {
133        let mut inner = self.inner.lock().unwrap();
134        Self::clear_status_lines(&mut inner);
135
136        if !inner.last_line_is_logging {
137            let reset = "\x1b[0m";
138            let dim = "\x1b[2m";
139            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
140            eprintln!("{dim}{divider}{reset}");
141            inner.last_line_is_logging = true;
142        }
143
144        let max_width = inner.terminal_width.saturating_sub(MARGIN);
145        for line in message.lines() {
146            let truncated = truncate_to_visible_width(line, max_width);
147            if truncated.len() < line.len() {
148                eprintln!("{}", truncated);
149            } else {
150                eprintln!("{}", truncated);
151            }
152        }
153
154        Self::print_status_lines(&mut inner);
155    }
156
157    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
158        let mut inner = self.inner.lock().unwrap();
159
160        Self::clear_status_lines(&mut inner);
161
162        let max_name_width = inner
163            .terminal_width
164            .saturating_sub(MARGIN * 2)
165            .saturating_div(3)
166            .max(1);
167        inner.max_example_name_len = inner
168            .max_example_name_len
169            .max(example_name.len().min(max_name_width));
170        inner.in_progress.insert(
171            example_name.to_string(),
172            InProgressTask {
173                step,
174                started_at: Instant::now(),
175                substatus: None,
176                info: None,
177            },
178        );
179
180        if inner.is_tty && inner.ticker.is_none() {
181            let progress = self.clone();
182            inner.ticker = Some(std::thread::spawn(move || {
183                loop {
184                    std::thread::sleep(STATUS_TICK_INTERVAL);
185
186                    let mut inner = progress.inner.lock().unwrap();
187                    if inner.in_progress.is_empty() {
188                        break;
189                    }
190
191                    Progress::clear_status_lines(&mut inner);
192                    Progress::print_status_lines(&mut inner);
193                }
194            }));
195        }
196
197        Self::print_status_lines(&mut inner);
198
199        StepProgress {
200            progress: self.clone(),
201            step,
202            example_name: example_name.to_string(),
203        }
204    }
205
206    fn finish(&self, step: Step, example_name: &str) {
207        let mut inner = self.inner.lock().unwrap();
208
209        let Some(task) = inner.in_progress.remove(example_name) else {
210            return;
211        };
212
213        if task.step == step {
214            inner.completed.push(CompletedTask {
215                step: task.step,
216                example_name: example_name.to_string(),
217                duration: task.started_at.elapsed(),
218                info: task.info,
219            });
220
221            Self::clear_status_lines(&mut inner);
222            Self::print_logging_closing_divider(&mut inner);
223            if let Some(last_completed) = inner.completed.last() {
224                Self::print_completed(&inner, last_completed);
225            }
226            Self::print_status_lines(&mut inner);
227        } else {
228            inner.in_progress.insert(example_name.to_string(), task);
229        }
230    }
231
232    fn print_logging_closing_divider(inner: &mut ProgressInner) {
233        if inner.last_line_is_logging {
234            let reset = "\x1b[0m";
235            let dim = "\x1b[2m";
236            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
237            eprintln!("{dim}{divider}{reset}");
238            inner.last_line_is_logging = false;
239        }
240    }
241
242    fn clear_status_lines(inner: &mut ProgressInner) {
243        if inner.is_tty && inner.status_lines_displayed > 0 {
244            // Move up and clear each line we previously displayed
245            for _ in 0..inner.status_lines_displayed {
246                eprint!("\x1b[A\x1b[K");
247            }
248            let _ = std::io::stderr().flush();
249            inner.status_lines_displayed = 0;
250        }
251    }
252
253    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
254        let duration = format_duration(task.duration);
255        let name_width = inner.max_example_name_len;
256        let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
257
258        if inner.is_tty {
259            let reset = "\x1b[0m";
260            let bold = "\x1b[1m";
261            let dim = "\x1b[2m";
262
263            let yellow = "\x1b[33m";
264            let info_part = task
265                .info
266                .as_ref()
267                .map(|(s, style)| {
268                    if *style == InfoStyle::Warning {
269                        format!("{yellow}{s}{reset}")
270                    } else {
271                        s.to_string()
272                    }
273                })
274                .unwrap_or_default();
275
276            let prefix = format!(
277                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
278                color = task.step.color_code(),
279                label = task.step.label(),
280                name = truncated_name,
281            );
282
283            let duration_with_margin = format!("{duration} ");
284            let padding_needed = inner
285                .terminal_width
286                .saturating_sub(MARGIN)
287                .saturating_sub(duration_with_margin.len())
288                .saturating_sub(strip_ansi_len(&prefix));
289            let padding = " ".repeat(padding_needed);
290
291            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
292        } else {
293            let info_part = task
294                .info
295                .as_ref()
296                .map(|(s, _)| format!(" | {}", s))
297                .unwrap_or_default();
298
299            eprintln!(
300                "{label:>12} {name:<name_width$}{info_part} {duration}",
301                label = task.step.label(),
302                name = truncate_with_ellipsis(&task.example_name, name_width),
303            );
304        }
305    }
306
307    fn print_status_lines(inner: &mut ProgressInner) {
308        if !inner.is_tty || inner.in_progress.is_empty() {
309            inner.status_lines_displayed = 0;
310            return;
311        }
312
313        let reset = "\x1b[0m";
314        let bold = "\x1b[1m";
315        let dim = "\x1b[2m";
316
317        // Build the done/in-progress/total label
318        let done_count = inner.completed.len();
319        let in_progress_count = inner.in_progress.len();
320        let failed_count = inner.failed_examples;
321
322        let failed_label = if failed_count > 0 {
323            format!(" {} failed ", failed_count)
324        } else {
325            String::new()
326        };
327
328        let range_label = format!(
329            " {}/{}/{} ",
330            done_count, in_progress_count, inner.total_examples
331        );
332
333        // Print a divider line with failed count on left, range label on right
334        let failed_visible_len = strip_ansi_len(&failed_label);
335        let range_visible_len = range_label.len();
336        let middle_divider_len = inner
337            .terminal_width
338            .saturating_sub(MARGIN * 2)
339            .saturating_sub(failed_visible_len)
340            .saturating_sub(range_visible_len);
341        let left_divider = "".repeat(MARGIN);
342        let middle_divider = "".repeat(middle_divider_len);
343        let right_divider = "".repeat(MARGIN);
344        eprintln!(
345            "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
346        );
347
348        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
349        tasks.sort_by_key(|(name, _)| *name);
350
351        let total_tasks = tasks.len();
352        let mut lines_printed = 0;
353
354        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
355            let elapsed = format_duration(task.started_at.elapsed());
356            let substatus_part = task
357                .substatus
358                .as_ref()
359                .map(|s| truncate_with_ellipsis(s, 30))
360                .unwrap_or_default();
361
362            let step_label = task.step.label();
363            let step_color = task.step.color_code();
364            let name_width = inner.max_example_name_len;
365            let truncated_name = truncate_with_ellipsis(name, name_width);
366
367            let prefix = format!(
368                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
369                name = truncated_name,
370            );
371
372            let duration_with_margin = format!("{elapsed} ");
373            let padding_needed = inner
374                .terminal_width
375                .saturating_sub(MARGIN)
376                .saturating_sub(duration_with_margin.len())
377                .saturating_sub(strip_ansi_len(&prefix));
378            let padding = " ".repeat(padding_needed);
379
380            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
381            lines_printed += 1;
382        }
383
384        // Show "+N more" on its own line if there are more tasks
385        if total_tasks > MAX_STATUS_LINES {
386            let remaining = total_tasks - MAX_STATUS_LINES;
387            eprintln!("{:>12} +{remaining} more", "");
388            lines_printed += 1;
389        }
390
391        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
392        let _ = std::io::stderr().flush();
393    }
394
395    pub fn finalize(&self) {
396        let ticker = {
397            let mut inner = self.inner.lock().unwrap();
398            inner.ticker.take()
399        };
400
401        if let Some(ticker) = ticker {
402            let _ = ticker.join();
403        }
404
405        let mut inner = self.inner.lock().unwrap();
406        Self::clear_status_lines(&mut inner);
407
408        // Print summary if there were failures
409        if inner.failed_examples > 0 {
410            let total_examples = inner.total_examples;
411            let percentage = if total_examples > 0 {
412                inner.failed_examples as f64 / total_examples as f64 * 100.0
413            } else {
414                0.0
415            };
416            eprintln!(
417                "\n{} of {} examples failed ({:.1}%)",
418                inner.failed_examples, total_examples, percentage
419            );
420        }
421    }
422}
423
424pub struct StepProgress {
425    progress: Arc<Progress>,
426    step: Step,
427    example_name: String,
428}
429
430impl StepProgress {
431    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
432        let mut inner = self.progress.inner.lock().unwrap();
433        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
434            task.substatus = Some(substatus.into().into_owned());
435            Progress::clear_status_lines(&mut inner);
436            Progress::print_status_lines(&mut inner);
437        }
438    }
439
440    pub fn clear_substatus(&self) {
441        let mut inner = self.progress.inner.lock().unwrap();
442        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
443            task.substatus = None;
444            Progress::clear_status_lines(&mut inner);
445            Progress::print_status_lines(&mut inner);
446        }
447    }
448
449    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
450        let mut inner = self.progress.inner.lock().unwrap();
451        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
452            task.info = Some((info.into(), style));
453        }
454    }
455}
456
457impl Drop for StepProgress {
458    fn drop(&mut self) {
459        self.progress.finish(self.step, &self.example_name);
460    }
461}
462
463struct ProgressLogger;
464
465impl Log for ProgressLogger {
466    fn enabled(&self, metadata: &Metadata) -> bool {
467        metadata.level() <= Level::Info
468    }
469
470    fn log(&self, record: &Record) {
471        if !self.enabled(record.metadata()) {
472            return;
473        }
474
475        let level_color = match record.level() {
476            Level::Error => "\x1b[31m",
477            Level::Warn => "\x1b[33m",
478            Level::Info => "\x1b[32m",
479            Level::Debug => "\x1b[34m",
480            Level::Trace => "\x1b[35m",
481        };
482        let reset = "\x1b[0m";
483        let bold = "\x1b[1m";
484
485        let level_label = match record.level() {
486            Level::Error => "Error",
487            Level::Warn => "Warn",
488            Level::Info => "Info",
489            Level::Debug => "Debug",
490            Level::Trace => "Trace",
491        };
492
493        let message = format!(
494            "{bold}{level_color}{level_label:>12}{reset} {}",
495            record.args()
496        );
497
498        if let Some(progress) = GLOBAL.get() {
499            progress.log(&message);
500        } else {
501            eprintln!("{}", message);
502        }
503    }
504
505    fn flush(&self) {
506        let _ = std::io::stderr().flush();
507    }
508}
509
510#[cfg(unix)]
511fn get_terminal_width() -> usize {
512    unsafe {
513        let mut winsize: libc::winsize = std::mem::zeroed();
514        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
515            && winsize.ws_col > 0
516        {
517            winsize.ws_col as usize
518        } else {
519            80
520        }
521    }
522}
523
#[cfg(not(unix))]
/// Non-Unix fallback: the width is not queried; a conventional 80 columns is assumed.
fn get_terminal_width() -> usize {
    const DEFAULT_COLUMNS: usize = 80;
    DEFAULT_COLUMNS
}
528
/// Counts the characters of `s` that are visible, skipping ANSI escape sequences
/// (from ESC up to and including the terminating `m`).
fn strip_ansi_len(s: &str) -> usize {
    let (visible, _) = s
        .chars()
        .fold((0usize, false), |(count, escaping), ch| match (escaping, ch) {
            (_, '\x1b') => (count, true),
            (true, 'm') => (count, false),
            (true, _) => (count, true),
            (false, _) => (count + 1, false),
        });
    visible
}
545
/// Shortens `s` to at most `max_len` characters, replacing the tail with `…` when it is cut.
///
/// Lengths are counted in `char`s, matching how `{:<width$}` pads, so multi-byte names are
/// measured correctly and never sliced off a char boundary. This previously used byte
/// offsets, which panicked on non-ASCII input. Returns the input borrowed when it already
/// fits. A `max_len` of 0 yields an empty string.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> Cow<'_, str> {
    if s.chars().count() <= max_len {
        return Cow::Borrowed(s);
    }
    if max_len == 0 {
        return Cow::Borrowed("");
    }
    // One column goes to the ellipsis itself.
    let kept: String = s.chars().take(max_len - 1).collect();
    Cow::Owned(format!("{kept}…"))
}
553
/// Returns the longest prefix of `s` that shows at most `max_visible_len` visible characters.
///
/// ANSI escape sequences (ESC through the terminating `m`) take no width. Escapes that appear
/// before the cut are kept. The cut always falls on a char boundary. Returns `s` unchanged
/// when it already fits.
fn truncate_to_visible_width(s: &str, max_visible_len: usize) -> &str {
    let mut shown = 0;
    let mut escaping = false;
    for (offset, ch) in s.char_indices() {
        match (escaping, ch) {
            (_, '\x1b') => escaping = true,
            (true, 'm') => escaping = false,
            (true, _) => {}
            (false, _) if shown >= max_visible_len => return &s[..offset],
            (false, _) => shown += 1,
        }
    }
    s
}
575
/// Formats a duration compactly: whole milliseconds below one second, otherwise seconds
/// below one minute, otherwise minutes, the latter two with one decimal place.
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    match duration.as_millis() as f32 {
        ms if ms < MILLIS_PER_SECOND => format!("{ms}ms"),
        ms if ms < MILLIS_PER_MINUTE => format!("{:.1}s", ms / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms / MILLIS_PER_MINUTE),
    }
}