progress.rs

  1use std::{
  2    borrow::Cow,
  3    collections::HashMap,
  4    io::{IsTerminal, Write},
  5    sync::{Arc, Mutex, OnceLock},
  6    time::{Duration, Instant},
  7};
  8
  9use log::{Level, Log, Metadata, Record};
 10
 11pub struct Progress {
 12    inner: Mutex<ProgressInner>,
 13}
 14
 15struct ProgressInner {
 16    completed: Vec<CompletedTask>,
 17    in_progress: HashMap<String, InProgressTask>,
 18    is_tty: bool,
 19    terminal_width: usize,
 20    max_example_name_len: usize,
 21    status_lines_displayed: usize,
 22    total_steps: usize,
 23    failed_examples: usize,
 24    last_line_is_logging: bool,
 25    ticker: Option<std::thread::JoinHandle<()>>,
 26}
 27
 28#[derive(Clone)]
 29struct InProgressTask {
 30    step: Step,
 31    started_at: Instant,
 32    substatus: Option<String>,
 33    info: Option<(String, InfoStyle)>,
 34}
 35
 36struct CompletedTask {
 37    step: Step,
 38    example_name: String,
 39    duration: Duration,
 40    info: Option<(String, InfoStyle)>,
 41}
 42
 /// A pipeline stage whose progress is tracked per example.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Step {
    /// Displayed as `Load`.
    LoadProject,
    /// Displayed as `Context`.
    Context,
    /// Displayed as `Format`.
    FormatPrompt,
    /// Displayed as `Predict`.
    Predict,
    /// Displayed as `Score`.
    Score,
    /// Displayed as `Synthesize`.
    Synthesize,
    /// Displayed as `Pull`.
    PullExamples,
}
 53
 /// How a step's info text is rendered on its completion line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed as-is.
    Normal,
    /// Highlighted in yellow when stderr is a terminal.
    Warning,
}
 59
 60impl Step {
 61    pub fn label(&self) -> &'static str {
 62        match self {
 63            Step::LoadProject => "Load",
 64            Step::Context => "Context",
 65            Step::FormatPrompt => "Format",
 66            Step::Predict => "Predict",
 67            Step::Score => "Score",
 68            Step::Synthesize => "Synthesize",
 69            Step::PullExamples => "Pull",
 70        }
 71    }
 72
 73    fn color_code(&self) -> &'static str {
 74        match self {
 75            Step::LoadProject => "\x1b[33m",
 76            Step::Context => "\x1b[35m",
 77            Step::FormatPrompt => "\x1b[34m",
 78            Step::Predict => "\x1b[32m",
 79            Step::Score => "\x1b[31m",
 80            Step::Synthesize => "\x1b[36m",
 81            Step::PullExamples => "\x1b[36m",
 82        }
 83    }
 84}
 85
 86static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
 87static LOGGER: ProgressLogger = ProgressLogger;
 88
 /// Columns held back from the terminal width when sizing lines and dividers.
const MARGIN: usize = 4;
/// Upper bound on individually listed in-progress tasks; extras collapse into "+N more".
const MAX_STATUS_LINES: usize = 10;
/// Delay between redraws performed by the background ticker thread.
const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
 92
 93impl Progress {
 94    /// Returns the global Progress instance, initializing it if necessary.
 95    pub fn global() -> Arc<Progress> {
 96        GLOBAL
 97            .get_or_init(|| {
 98                let progress = Arc::new(Self {
 99                    inner: Mutex::new(ProgressInner {
100                        completed: Vec::new(),
101                        in_progress: HashMap::new(),
102                        is_tty: std::io::stderr().is_terminal(),
103                        terminal_width: get_terminal_width(),
104                        max_example_name_len: 0,
105                        status_lines_displayed: 0,
106                        total_steps: 0,
107                        failed_examples: 0,
108                        last_line_is_logging: false,
109                        ticker: None,
110                    }),
111                });
112                let _ = log::set_logger(&LOGGER);
113                log::set_max_level(log::LevelFilter::Error);
114                progress
115            })
116            .clone()
117    }
118
119    pub fn set_total_steps(&self, total: usize) {
120        let mut inner = self.inner.lock().unwrap();
121        inner.total_steps = total;
122    }
123
124    pub fn increment_failed(&self) {
125        let mut inner = self.inner.lock().unwrap();
126        inner.failed_examples += 1;
127    }
128
129    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
130    /// This should be used for any output that needs to appear above the status lines.
131    fn log(&self, message: &str) {
132        let mut inner = self.inner.lock().unwrap();
133        Self::clear_status_lines(&mut inner);
134
135        if !inner.last_line_is_logging {
136            let reset = "\x1b[0m";
137            let dim = "\x1b[2m";
138            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
139            eprintln!("{dim}{divider}{reset}");
140            inner.last_line_is_logging = true;
141        }
142
143        eprintln!("{}", message);
144    }
145
146    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
147        let mut inner = self.inner.lock().unwrap();
148
149        Self::clear_status_lines(&mut inner);
150
151        let max_name_width = inner
152            .terminal_width
153            .saturating_sub(MARGIN * 2)
154            .saturating_div(3)
155            .max(1);
156        inner.max_example_name_len = inner
157            .max_example_name_len
158            .max(example_name.len().min(max_name_width));
159        inner.in_progress.insert(
160            example_name.to_string(),
161            InProgressTask {
162                step,
163                started_at: Instant::now(),
164                substatus: None,
165                info: None,
166            },
167        );
168
169        if inner.is_tty && inner.ticker.is_none() {
170            let progress = self.clone();
171            inner.ticker = Some(std::thread::spawn(move || {
172                loop {
173                    std::thread::sleep(STATUS_TICK_INTERVAL);
174
175                    let mut inner = progress.inner.lock().unwrap();
176                    if inner.in_progress.is_empty() {
177                        break;
178                    }
179
180                    Progress::clear_status_lines(&mut inner);
181                    Progress::print_status_lines(&mut inner);
182                }
183            }));
184        }
185
186        Self::print_status_lines(&mut inner);
187
188        StepProgress {
189            progress: self.clone(),
190            step,
191            example_name: example_name.to_string(),
192        }
193    }
194
195    fn finish(&self, step: Step, example_name: &str) {
196        let mut inner = self.inner.lock().unwrap();
197
198        let Some(task) = inner.in_progress.remove(example_name) else {
199            return;
200        };
201
202        if task.step == step {
203            inner.completed.push(CompletedTask {
204                step: task.step,
205                example_name: example_name.to_string(),
206                duration: task.started_at.elapsed(),
207                info: task.info,
208            });
209
210            Self::clear_status_lines(&mut inner);
211            Self::print_logging_closing_divider(&mut inner);
212            if let Some(last_completed) = inner.completed.last() {
213                Self::print_completed(&inner, last_completed);
214            }
215            Self::print_status_lines(&mut inner);
216        } else {
217            inner.in_progress.insert(example_name.to_string(), task);
218        }
219    }
220
221    fn print_logging_closing_divider(inner: &mut ProgressInner) {
222        if inner.last_line_is_logging {
223            let reset = "\x1b[0m";
224            let dim = "\x1b[2m";
225            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
226            eprintln!("{dim}{divider}{reset}");
227            inner.last_line_is_logging = false;
228        }
229    }
230
231    fn clear_status_lines(inner: &mut ProgressInner) {
232        if inner.is_tty && inner.status_lines_displayed > 0 {
233            // Move up and clear each line we previously displayed
234            for _ in 0..inner.status_lines_displayed {
235                eprint!("\x1b[A\x1b[K");
236            }
237            let _ = std::io::stderr().flush();
238            inner.status_lines_displayed = 0;
239        }
240    }
241
242    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
243        let duration = format_duration(task.duration);
244        let name_width = inner.max_example_name_len;
245        let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
246
247        if inner.is_tty {
248            let reset = "\x1b[0m";
249            let bold = "\x1b[1m";
250            let dim = "\x1b[2m";
251
252            let yellow = "\x1b[33m";
253            let info_part = task
254                .info
255                .as_ref()
256                .map(|(s, style)| {
257                    if *style == InfoStyle::Warning {
258                        format!("{yellow}{s}{reset}")
259                    } else {
260                        s.to_string()
261                    }
262                })
263                .unwrap_or_default();
264
265            let prefix = format!(
266                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
267                color = task.step.color_code(),
268                label = task.step.label(),
269                name = truncated_name,
270            );
271
272            let duration_with_margin = format!("{duration} ");
273            let padding_needed = inner
274                .terminal_width
275                .saturating_sub(MARGIN)
276                .saturating_sub(duration_with_margin.len())
277                .saturating_sub(strip_ansi_len(&prefix));
278            let padding = " ".repeat(padding_needed);
279
280            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
281        } else {
282            let info_part = task
283                .info
284                .as_ref()
285                .map(|(s, _)| format!(" | {}", s))
286                .unwrap_or_default();
287
288            eprintln!(
289                "{label:>12} {name:<name_width$}{info_part} {duration}",
290                label = task.step.label(),
291                name = truncate_with_ellipsis(&task.example_name, name_width),
292            );
293        }
294    }
295
296    fn print_status_lines(inner: &mut ProgressInner) {
297        if !inner.is_tty || inner.in_progress.is_empty() {
298            inner.status_lines_displayed = 0;
299            return;
300        }
301
302        let reset = "\x1b[0m";
303        let bold = "\x1b[1m";
304        let dim = "\x1b[2m";
305
306        // Build the done/in-progress/total label
307        let done_count = inner.completed.len();
308        let in_progress_count = inner.in_progress.len();
309        let failed_count = inner.failed_examples;
310
311        let failed_label = if failed_count > 0 {
312            format!(" {} failed ", failed_count)
313        } else {
314            String::new()
315        };
316
317        let range_label = format!(
318            " {}/{}/{} ",
319            done_count, in_progress_count, inner.total_steps
320        );
321
322        // Print a divider line with failed count on left, range label on right
323        let failed_visible_len = strip_ansi_len(&failed_label);
324        let range_visible_len = range_label.len();
325        let middle_divider_len = inner
326            .terminal_width
327            .saturating_sub(MARGIN * 2)
328            .saturating_sub(failed_visible_len)
329            .saturating_sub(range_visible_len);
330        let left_divider = "".repeat(MARGIN);
331        let middle_divider = "".repeat(middle_divider_len);
332        let right_divider = "".repeat(MARGIN);
333        eprintln!(
334            "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
335        );
336
337        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
338        tasks.sort_by_key(|(name, _)| *name);
339
340        let total_tasks = tasks.len();
341        let mut lines_printed = 0;
342
343        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
344            let elapsed = format_duration(task.started_at.elapsed());
345            let substatus_part = task
346                .substatus
347                .as_ref()
348                .map(|s| truncate_with_ellipsis(s, 30))
349                .unwrap_or_default();
350
351            let step_label = task.step.label();
352            let step_color = task.step.color_code();
353            let name_width = inner.max_example_name_len;
354            let truncated_name = truncate_with_ellipsis(name, name_width);
355
356            let prefix = format!(
357                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
358                name = truncated_name,
359            );
360
361            let duration_with_margin = format!("{elapsed} ");
362            let padding_needed = inner
363                .terminal_width
364                .saturating_sub(MARGIN)
365                .saturating_sub(duration_with_margin.len())
366                .saturating_sub(strip_ansi_len(&prefix));
367            let padding = " ".repeat(padding_needed);
368
369            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
370            lines_printed += 1;
371        }
372
373        // Show "+N more" on its own line if there are more tasks
374        if total_tasks > MAX_STATUS_LINES {
375            let remaining = total_tasks - MAX_STATUS_LINES;
376            eprintln!("{:>12} +{remaining} more", "");
377            lines_printed += 1;
378        }
379
380        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
381        let _ = std::io::stderr().flush();
382    }
383
384    pub fn finalize(&self) {
385        let ticker = {
386            let mut inner = self.inner.lock().unwrap();
387            inner.ticker.take()
388        };
389
390        if let Some(ticker) = ticker {
391            let _ = ticker.join();
392        }
393
394        let mut inner = self.inner.lock().unwrap();
395        Self::clear_status_lines(&mut inner);
396
397        // Print summary if there were failures
398        if inner.failed_examples > 0 {
399            let total_processed = inner.completed.len() + inner.failed_examples;
400            let percentage = if total_processed > 0 {
401                inner.failed_examples as f64 / total_processed as f64 * 100.0
402            } else {
403                0.0
404            };
405            eprintln!(
406                "\n{} of {} examples failed ({:.1}%)",
407                inner.failed_examples, total_processed, percentage
408            );
409        }
410    }
411}
412
413pub struct StepProgress {
414    progress: Arc<Progress>,
415    step: Step,
416    example_name: String,
417}
418
419impl StepProgress {
420    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
421        let mut inner = self.progress.inner.lock().unwrap();
422        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
423            task.substatus = Some(substatus.into().into_owned());
424            Progress::clear_status_lines(&mut inner);
425            Progress::print_status_lines(&mut inner);
426        }
427    }
428
429    pub fn clear_substatus(&self) {
430        let mut inner = self.progress.inner.lock().unwrap();
431        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
432            task.substatus = None;
433            Progress::clear_status_lines(&mut inner);
434            Progress::print_status_lines(&mut inner);
435        }
436    }
437
438    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
439        let mut inner = self.progress.inner.lock().unwrap();
440        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
441            task.info = Some((info.into(), style));
442        }
443    }
444}
445
446impl Drop for StepProgress {
447    fn drop(&mut self) {
448        self.progress.finish(self.step, &self.example_name);
449    }
450}
451
/// `log` backend that formats records with a colored level label and prints them through
/// the global `Progress` when one exists, or straight to stderr otherwise.
struct ProgressLogger;
453
454impl Log for ProgressLogger {
455    fn enabled(&self, metadata: &Metadata) -> bool {
456        metadata.level() <= Level::Info
457    }
458
459    fn log(&self, record: &Record) {
460        if !self.enabled(record.metadata()) {
461            return;
462        }
463
464        let level_color = match record.level() {
465            Level::Error => "\x1b[31m",
466            Level::Warn => "\x1b[33m",
467            Level::Info => "\x1b[32m",
468            Level::Debug => "\x1b[34m",
469            Level::Trace => "\x1b[35m",
470        };
471        let reset = "\x1b[0m";
472        let bold = "\x1b[1m";
473
474        let level_label = match record.level() {
475            Level::Error => "Error",
476            Level::Warn => "Warn",
477            Level::Info => "Info",
478            Level::Debug => "Debug",
479            Level::Trace => "Trace",
480        };
481
482        let message = format!(
483            "{bold}{level_color}{level_label:>12}{reset} {}",
484            record.args()
485        );
486
487        if let Some(progress) = GLOBAL.get() {
488            progress.log(&message);
489        } else {
490            eprintln!("{}", message);
491        }
492    }
493
494    fn flush(&self) {
495        let _ = std::io::stderr().flush();
496    }
497}
498
499#[cfg(unix)]
500fn get_terminal_width() -> usize {
501    unsafe {
502        let mut winsize: libc::winsize = std::mem::zeroed();
503        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
504            && winsize.ws_col > 0
505        {
506            winsize.ws_col as usize
507        } else {
508            80
509        }
510    }
511}
512
513#[cfg(not(unix))]
514fn get_terminal_width() -> usize {
515    80
516}
517
/// Counts the `char`s of `s` that stay visible once ANSI escape sequences are removed.
///
/// A sequence is taken to run from an ESC (`\x1b`) up to and including the next `m`, which
/// matches the SGR color/style codes used in this module. An unterminated sequence hides
/// the rest of the string.
fn strip_ansi_len(s: &str) -> usize {
    let mut chars = s.chars();
    let mut visible = 0;
    while let Some(c) = chars.next() {
        if c == '\x1b' {
            // Consume the remainder of the escape sequence, including its closing `m`.
            for inner in chars.by_ref() {
                if inner == 'm' {
                    break;
                }
            }
        } else {
            visible += 1;
        }
    }
    visible
}
534
/// Shortens `s` to at most `max_len` characters, marking any cut with a trailing `…`.
///
/// Lengths are counted in `char`s, matching how `{:<width$}` padding counts. A string that
/// already fits is returned unchanged. Otherwise the first `max_len - 1` characters are
/// kept, followed by `…`. When `max_len` is 0, the result is empty.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    // `nth(max_len)` exists only when `s` has more than `max_len` chars; this avoids
    // counting the whole string.
    if s.chars().nth(max_len).is_none() {
        return s.to_string();
    }
    match max_len.checked_sub(1) {
        None => String::new(),
        Some(keep) => {
            // Iterating chars (not slicing bytes) avoids panicking on multi-byte UTF-8.
            let mut truncated: String = s.chars().take(keep).collect();
            truncated.push('…');
            truncated
        }
    }
}
542
/// Formats a duration compactly for display:
/// - below one second: whole milliseconds (`"250ms"`);
/// - below one minute: seconds with one decimal (`"12.3s"`);
/// - otherwise: minutes with one decimal (`"1.5m"`).
fn format_duration(duration: Duration) -> String {
    const SECOND_IN_MILLIS: f32 = 1_000.0;
    const MINUTE_IN_MILLIS: f32 = 60.0 * SECOND_IN_MILLIS;

    // Whole milliseconds, kept as f32 so the rounding of the decimal forms is unchanged.
    let millis = duration.as_millis() as f32;
    match millis {
        m if m < SECOND_IN_MILLIS => format!("{m}ms"),
        m if m < MINUTE_IN_MILLIS => format!("{:.1}s", m / SECOND_IN_MILLIS),
        m => format!("{:.1}m", m / MINUTE_IN_MILLIS),
    }
}