1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 io::{IsTerminal, Write},
5 sync::{Arc, Mutex, OnceLock},
6 time::{Duration, Instant},
7};
8
9use crate::paths::RUN_DIR;
10
11use log::{Level, Log, Metadata, Record};
12
13pub struct Progress {
14 inner: Mutex<ProgressInner>,
15}
16
17struct ProgressInner {
18 completed: Vec<CompletedTask>,
19 in_progress: HashMap<String, InProgressTask>,
20 is_tty: bool,
21 terminal_width: usize,
22 max_example_name_len: usize,
23 status_lines_displayed: usize,
24 total_examples: usize,
25 completed_examples: usize,
26 failed_examples: usize,
27 last_line_is_logging: bool,
28 ticker: Option<std::thread::JoinHandle<()>>,
29}
30
31#[derive(Clone)]
32struct InProgressTask {
33 step: Step,
34 started_at: Instant,
35 substatus: Option<String>,
36 info: Option<(String, InfoStyle)>,
37}
38
39struct CompletedTask {
40 step: Step,
41 example_name: String,
42 duration: Duration,
43 info: Option<(String, InfoStyle)>,
44}
45
/// A unit of work whose progress is reported per example.
///
/// Each variant has a short display label and a color (see `impl Step`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
    /// Displayed as "Load".
    LoadProject,
    /// Displayed as "Context".
    Context,
    /// Displayed as "Format".
    FormatPrompt,
    /// Displayed as "Predict".
    Predict,
    /// Displayed as "Score".
    Score,
    /// Displayed as "Synthesize".
    Synthesize,
    /// Displayed as "Pull".
    PullExamples,
}
56
/// How a step's info note is rendered on its completion line.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed without extra styling.
    Normal,
    /// Printed in yellow when styled output is enabled.
    Warning,
}
62
63impl Step {
64 pub fn label(&self) -> &'static str {
65 match self {
66 Step::LoadProject => "Load",
67 Step::Context => "Context",
68 Step::FormatPrompt => "Format",
69 Step::Predict => "Predict",
70 Step::Score => "Score",
71 Step::Synthesize => "Synthesize",
72 Step::PullExamples => "Pull",
73 }
74 }
75
76 fn color_code(&self) -> &'static str {
77 match self {
78 Step::LoadProject => "\x1b[33m",
79 Step::Context => "\x1b[35m",
80 Step::FormatPrompt => "\x1b[34m",
81 Step::Predict => "\x1b[32m",
82 Step::Score => "\x1b[31m",
83 Step::Synthesize => "\x1b[36m",
84 Step::PullExamples => "\x1b[36m",
85 }
86 }
87}
88
89static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
90static LOGGER: ProgressLogger = ProgressLogger;
91
92const MARGIN: usize = 4;
93const MAX_STATUS_LINES: usize = 10;
94const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
95
96impl Progress {
97 /// Returns the global Progress instance, initializing it if necessary.
98 pub fn global() -> Arc<Progress> {
99 GLOBAL
100 .get_or_init(|| {
101 let progress = Arc::new(Self {
102 inner: Mutex::new(ProgressInner {
103 completed: Vec::new(),
104 in_progress: HashMap::new(),
105 is_tty: std::env::var("COLOR").is_ok()
106 || (std::env::var("NO_COLOR").is_err()
107 && std::io::stderr().is_terminal()),
108 terminal_width: get_terminal_width(),
109 max_example_name_len: 0,
110 status_lines_displayed: 0,
111 total_examples: 0,
112 completed_examples: 0,
113 failed_examples: 0,
114 last_line_is_logging: false,
115 ticker: None,
116 }),
117 });
118 let _ = log::set_logger(&LOGGER);
119 log::set_max_level(log::LevelFilter::Info);
120 progress
121 })
122 .clone()
123 }
124
125 pub fn start_group(self: &Arc<Self>, example_name: &str) -> ExampleProgress {
126 ExampleProgress {
127 progress: self.clone(),
128 example_name: example_name.to_string(),
129 }
130 }
131
132 fn increment_completed(&self) {
133 let mut inner = self.inner.lock().unwrap();
134 inner.completed_examples += 1;
135 }
136
137 pub fn set_total_examples(&self, total: usize) {
138 let mut inner = self.inner.lock().unwrap();
139 inner.total_examples = total;
140 }
141
142 pub fn set_max_example_name_len(&self, example_names: impl Iterator<Item = impl AsRef<str>>) {
143 let mut inner = self.inner.lock().unwrap();
144 let max_name_width = inner
145 .terminal_width
146 .saturating_sub(MARGIN * 2)
147 .saturating_div(3)
148 .max(1);
149 inner.max_example_name_len = example_names
150 .map(|name| name.as_ref().len().min(max_name_width))
151 .max()
152 .unwrap_or(0);
153 }
154
155 pub fn increment_failed(&self) {
156 let mut inner = self.inner.lock().unwrap();
157 inner.failed_examples += 1;
158 }
159
160 /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
161 /// This should be used for any output that needs to appear above the status lines.
162 fn log(&self, message: &str) {
163 let mut inner = self.inner.lock().unwrap();
164 Self::clear_status_lines(&mut inner);
165
166 if !inner.last_line_is_logging {
167 let reset = "\x1b[0m";
168 let dim = "\x1b[2m";
169 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
170 eprintln!("{dim}{divider}{reset}");
171 inner.last_line_is_logging = true;
172 }
173
174 let max_width = inner.terminal_width.saturating_sub(MARGIN);
175 for line in message.lines() {
176 let truncated = truncate_to_visible_width(line, max_width);
177 if truncated.len() < line.len() {
178 eprintln!("{}…", truncated);
179 } else {
180 eprintln!("{}", truncated);
181 }
182 }
183
184 Self::print_status_lines(&mut inner);
185 }
186
187 pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
188 let mut inner = self.inner.lock().unwrap();
189
190 Self::clear_status_lines(&mut inner);
191
192 // Update max_example_name_len if not already set via set_max_example_name_len
193 if inner.max_example_name_len == 0 {
194 let max_name_width = inner
195 .terminal_width
196 .saturating_sub(MARGIN * 2)
197 .saturating_div(3)
198 .max(1);
199 inner.max_example_name_len = example_name.len().min(max_name_width);
200 }
201 inner.in_progress.insert(
202 example_name.to_string(),
203 InProgressTask {
204 step,
205 started_at: Instant::now(),
206 substatus: None,
207 info: None,
208 },
209 );
210
211 if inner.is_tty && inner.ticker.is_none() {
212 let progress = self.clone();
213 inner.ticker = Some(std::thread::spawn(move || {
214 loop {
215 std::thread::sleep(STATUS_TICK_INTERVAL);
216
217 let mut inner = progress.inner.lock().unwrap();
218 if inner.in_progress.is_empty() {
219 break;
220 }
221
222 Progress::clear_status_lines(&mut inner);
223 Progress::print_status_lines(&mut inner);
224 }
225 }));
226 }
227
228 Self::print_status_lines(&mut inner);
229
230 StepProgress {
231 progress: self.clone(),
232 step,
233 example_name: example_name.to_string(),
234 }
235 }
236
237 fn finish(&self, step: Step, example_name: &str) {
238 let mut inner = self.inner.lock().unwrap();
239
240 let Some(task) = inner.in_progress.remove(example_name) else {
241 return;
242 };
243
244 if task.step == step {
245 let duration = task.started_at.elapsed();
246
247 // Skip logging for tasks that complete quickly (under 500ms)
248 let should_print = duration >= Duration::from_millis(500);
249
250 inner.completed.push(CompletedTask {
251 step: task.step,
252 example_name: example_name.to_string(),
253 duration,
254 info: task.info,
255 });
256
257 Self::clear_status_lines(&mut inner);
258 if should_print {
259 Self::print_logging_closing_divider(&mut inner);
260 if let Some(last_completed) = inner.completed.last() {
261 Self::print_completed(&inner, last_completed);
262 }
263 }
264 Self::print_status_lines(&mut inner);
265 } else {
266 inner.in_progress.insert(example_name.to_string(), task);
267 }
268 }
269
270 fn print_logging_closing_divider(inner: &mut ProgressInner) {
271 if inner.last_line_is_logging {
272 let reset = "\x1b[0m";
273 let dim = "\x1b[2m";
274 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
275 eprintln!("{dim}{divider}{reset}");
276 inner.last_line_is_logging = false;
277 }
278 }
279
280 fn clear_status_lines(inner: &mut ProgressInner) {
281 if inner.is_tty && inner.status_lines_displayed > 0 {
282 // Move up and clear each line we previously displayed
283 for _ in 0..inner.status_lines_displayed {
284 eprint!("\x1b[A\x1b[K");
285 }
286 inner.status_lines_displayed = 0;
287 }
288 }
289
290 fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
291 let duration = format_duration(task.duration);
292 let name_width = inner.max_example_name_len;
293 let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
294
295 if inner.is_tty {
296 let reset = "\x1b[0m";
297 let bold = "\x1b[1m";
298 let dim = "\x1b[2m";
299
300 let yellow = "\x1b[33m";
301 let info_part = task
302 .info
303 .as_ref()
304 .map(|(s, style)| {
305 if *style == InfoStyle::Warning {
306 format!("{yellow}{s}{reset}")
307 } else {
308 s.to_string()
309 }
310 })
311 .unwrap_or_default();
312
313 let prefix = format!(
314 "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
315 color = task.step.color_code(),
316 label = task.step.label(),
317 name = truncated_name,
318 );
319
320 let duration_with_margin = format!("{duration} ");
321 let padding_needed = inner
322 .terminal_width
323 .saturating_sub(MARGIN)
324 .saturating_sub(duration_with_margin.len())
325 .saturating_sub(strip_ansi_len(&prefix));
326 let padding = " ".repeat(padding_needed);
327
328 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
329 } else {
330 let info_part = task
331 .info
332 .as_ref()
333 .map(|(s, _)| format!(" | {}", s))
334 .unwrap_or_default();
335
336 eprintln!(
337 "{label:>12} {name:<name_width$}{info_part} {duration}",
338 label = task.step.label(),
339 name = truncate_with_ellipsis(&task.example_name, name_width),
340 );
341 }
342 }
343
344 fn print_status_lines(inner: &mut ProgressInner) {
345 if !inner.is_tty || inner.in_progress.is_empty() {
346 inner.status_lines_displayed = 0;
347 return;
348 }
349
350 let reset = "\x1b[0m";
351 let bold = "\x1b[1m";
352 let dim = "\x1b[2m";
353
354 // Build the done/in-progress/total label
355 let done_count = inner.completed_examples;
356 let in_progress_count = inner.in_progress.len();
357 let failed_count = inner.failed_examples;
358
359 let failed_label = if failed_count > 0 {
360 format!(" {} failed ", failed_count)
361 } else {
362 String::new()
363 };
364
365 let range_label = format!(
366 " {}/{}/{} ",
367 done_count, in_progress_count, inner.total_examples
368 );
369
370 // Print a divider line with failed count on left, range label on right
371 let failed_visible_len = strip_ansi_len(&failed_label);
372 let range_visible_len = range_label.len();
373 let middle_divider_len = inner
374 .terminal_width
375 .saturating_sub(MARGIN * 2)
376 .saturating_sub(failed_visible_len)
377 .saturating_sub(range_visible_len);
378 let left_divider = "─".repeat(MARGIN);
379 let middle_divider = "─".repeat(middle_divider_len);
380 let right_divider = "─".repeat(MARGIN);
381 eprintln!(
382 "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
383 );
384
385 let mut tasks: Vec<_> = inner.in_progress.iter().collect();
386 tasks.sort_by_key(|(name, _)| *name);
387
388 let total_tasks = tasks.len();
389 let mut lines_printed = 0;
390
391 for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
392 let elapsed = format_duration(task.started_at.elapsed());
393 let substatus_part = task
394 .substatus
395 .as_ref()
396 .map(|s| truncate_with_ellipsis(s, 30))
397 .unwrap_or_default();
398
399 let step_label = task.step.label();
400 let step_color = task.step.color_code();
401 let name_width = inner.max_example_name_len;
402 let truncated_name = truncate_with_ellipsis(name, name_width);
403
404 let prefix = format!(
405 "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
406 name = truncated_name,
407 );
408
409 let duration_with_margin = format!("{elapsed} ");
410 let padding_needed = inner
411 .terminal_width
412 .saturating_sub(MARGIN)
413 .saturating_sub(duration_with_margin.len())
414 .saturating_sub(strip_ansi_len(&prefix));
415 let padding = " ".repeat(padding_needed);
416
417 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
418 lines_printed += 1;
419 }
420
421 // Show "+N more" on its own line if there are more tasks
422 if total_tasks > MAX_STATUS_LINES {
423 let remaining = total_tasks - MAX_STATUS_LINES;
424 eprintln!("{:>12} +{remaining} more", "");
425 lines_printed += 1;
426 }
427
428 inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
429 let _ = std::io::stderr().flush();
430 }
431
432 pub fn finalize(&self) {
433 let ticker = {
434 let mut inner = self.inner.lock().unwrap();
435 inner.ticker.take()
436 };
437
438 if let Some(ticker) = ticker {
439 let _ = ticker.join();
440 }
441
442 let mut inner = self.inner.lock().unwrap();
443 Self::clear_status_lines(&mut inner);
444
445 // Print summary if there were failures
446 if inner.failed_examples > 0 {
447 let total_examples = inner.total_examples;
448 let percentage = if total_examples > 0 {
449 inner.failed_examples as f64 / total_examples as f64 * 100.0
450 } else {
451 0.0
452 };
453 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
454 eprintln!(
455 "\n{} of {} examples failed ({:.1}%)\nFailed examples: {}",
456 inner.failed_examples,
457 total_examples,
458 percentage,
459 failed_jsonl_path.display()
460 );
461 }
462 }
463}
464
465pub struct ExampleProgress {
466 progress: Arc<Progress>,
467 example_name: String,
468}
469
470impl ExampleProgress {
471 pub fn start(&self, step: Step) -> StepProgress {
472 self.progress.start(step, &self.example_name)
473 }
474}
475
476impl Drop for ExampleProgress {
477 fn drop(&mut self) {
478 self.progress.increment_completed();
479 }
480}
481
482pub struct StepProgress {
483 progress: Arc<Progress>,
484 step: Step,
485 example_name: String,
486}
487
488impl StepProgress {
489 pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
490 let mut inner = self.progress.inner.lock().unwrap();
491 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
492 task.substatus = Some(substatus.into().into_owned());
493 Progress::clear_status_lines(&mut inner);
494 Progress::print_status_lines(&mut inner);
495 }
496 }
497
498 pub fn clear_substatus(&self) {
499 let mut inner = self.progress.inner.lock().unwrap();
500 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
501 task.substatus = None;
502 Progress::clear_status_lines(&mut inner);
503 Progress::print_status_lines(&mut inner);
504 }
505 }
506
507 pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
508 let mut inner = self.progress.inner.lock().unwrap();
509 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
510 task.info = Some((info.into(), style));
511 }
512 }
513}
514
515impl Drop for StepProgress {
516 fn drop(&mut self) {
517 self.progress.finish(self.step, &self.example_name);
518 }
519}
520
/// Stateless `log` backend that prints records through the global progress display.
struct ProgressLogger;
522
523impl Log for ProgressLogger {
524 fn enabled(&self, metadata: &Metadata) -> bool {
525 metadata.level() <= Level::Info
526 }
527
528 fn log(&self, record: &Record) {
529 if !self.enabled(record.metadata()) {
530 return;
531 }
532
533 let level_color = match record.level() {
534 Level::Error => "\x1b[31m",
535 Level::Warn => "\x1b[33m",
536 Level::Info => "\x1b[32m",
537 Level::Debug => "\x1b[34m",
538 Level::Trace => "\x1b[35m",
539 };
540 let reset = "\x1b[0m";
541 let bold = "\x1b[1m";
542
543 let level_label = match record.level() {
544 Level::Error => "Error",
545 Level::Warn => "Warn",
546 Level::Info => "Info",
547 Level::Debug => "Debug",
548 Level::Trace => "Trace",
549 };
550
551 let message = format!(
552 "{bold}{level_color}{level_label:>12}{reset} {}",
553 record.args()
554 );
555
556 if let Some(progress) = GLOBAL.get() {
557 progress.log(&message);
558 } else {
559 eprintln!("{}", message);
560 }
561 }
562
563 fn flush(&self) {
564 let _ = std::io::stderr().flush();
565 }
566}
567
568#[cfg(unix)]
569fn get_terminal_width() -> usize {
570 unsafe {
571 let mut winsize: libc::winsize = std::mem::zeroed();
572 if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
573 && winsize.ws_col > 0
574 {
575 winsize.ws_col as usize
576 } else {
577 80
578 }
579 }
580}
581
/// Terminal width used on non-Unix targets, where no width query is implemented.
#[cfg(not(unix))]
fn get_terminal_width() -> usize {
    const DEFAULT_COLUMNS: usize = 80;
    DEFAULT_COLUMNS
}
586
/// Counts the characters of `s` that are visible, skipping ANSI escape sequences.
///
/// An escape starts at `ESC` (`\x1b`) and is treated as ending at the next `m`, which covers
/// the SGR color/style codes used in this file.
fn strip_ansi_len(s: &str) -> usize {
    let mut escaping = false;
    s.chars()
        .filter(|&c| {
            if c == '\x1b' {
                escaping = true;
                false
            } else if escaping {
                escaping = c != 'm';
                false
            } else {
                true
            }
        })
        .count()
}
603
/// Shortens `s` to at most `max_len` characters, replacing the tail with `…` when it is cut.
///
/// Lengths are measured in `char`s rather than bytes, so multi-byte UTF-8 input is never
/// sliced inside a character (byte slicing there would panic). This also matches how `{:<w$}`
/// padding counts width.
///
/// Returns the input borrowed when it already fits. When `max_len` is 0, any non-empty input
/// becomes just `"…"`.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> Cow<'_, str> {
    // A char at index `max_len` exists exactly when `s` has more than `max_len` chars.
    if s.char_indices().nth(max_len).is_none() {
        return Cow::Borrowed(s);
    }
    let keep = max_len.saturating_sub(1);
    let cut = s
        .char_indices()
        .nth(keep)
        .map_or(s.len(), |(byte_index, _)| byte_index);
    Cow::Owned(format!("{}…", &s[..cut]))
}
611
/// Returns the longest prefix of `s` containing at most `max_visible_len` visible characters.
///
/// ANSI escape sequences (from `ESC` up to the next `m`) do not count towards the width. They
/// are kept when they fall before the cut, and when the whole string fits. The cut always lands
/// on a char boundary, never inside an escape sequence.
fn truncate_to_visible_width(s: &str, max_visible_len: usize) -> &str {
    let mut visible_seen = 0usize;
    let mut inside_escape = false;

    for (offset, ch) in s.char_indices() {
        if ch == '\x1b' {
            inside_escape = true;
            continue;
        }
        if inside_escape {
            inside_escape = ch != 'm';
            continue;
        }
        // `offset` is where this visible char starts, i.e. the end of everything kept so far.
        if visible_seen >= max_visible_len {
            return &s[..offset];
        }
        visible_seen += 1;
    }

    s
}
633
/// Renders a duration compactly.
///
/// Whole milliseconds are shown below one second (`"250ms"`), seconds with one decimal below
/// one minute (`"12.3s"`), and minutes with one decimal otherwise (`"1.5m"`).
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    let elapsed_ms = duration.as_millis() as f32;
    match elapsed_ms {
        ms if ms < MILLIS_PER_SECOND => format!("{ms}ms"),
        ms if ms < MILLIS_PER_MINUTE => format!("{:.1}s", ms / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms / MILLIS_PER_MINUTE),
    }
}