1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 io::{IsTerminal, Write},
5 sync::{Arc, Mutex, OnceLock},
6 time::{Duration, Instant},
7};
8
9use crate::paths::RUN_DIR;
10
11use log::{Level, Log, Metadata, Record};
12
13pub struct Progress {
14 inner: Mutex<ProgressInner>,
15}
16
17struct ProgressInner {
18 completed: Vec<CompletedTask>,
19 in_progress: HashMap<String, InProgressTask>,
20 is_tty: bool,
21 terminal_width: usize,
22 max_example_name_len: usize,
23 status_lines_displayed: usize,
24 total_examples: usize,
25 completed_examples: usize,
26 failed_examples: usize,
27 last_line_is_logging: bool,
28 ticker: Option<std::thread::JoinHandle<()>>,
29}
30
31#[derive(Clone)]
32struct InProgressTask {
33 step: Step,
34 started_at: Instant,
35 substatus: Option<String>,
36 info: Option<(String, InfoStyle)>,
37}
38
39struct CompletedTask {
40 step: Step,
41 example_name: String,
42 duration: Duration,
43 info: Option<(String, InfoStyle)>,
44}
45
/// A stage of the per-example pipeline, rendered as a colored, right-aligned label.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
    LoadProject,
    Context,
    FormatPrompt,
    Predict,
    Score,
    Synthesize,
    PullExamples,
}

/// How the info text of a completed step is rendered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed as-is.
    Normal,
    /// Highlighted in yellow when colored output is enabled.
    Warning,
}

impl Step {
    /// Short human-readable name shown in the label column.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::LoadProject => "Load",
            Self::Context => "Context",
            Self::FormatPrompt => "Format",
            Self::Predict => "Predict",
            Self::Score => "Score",
            Self::Synthesize => "Synthesize",
            Self::PullExamples => "Pull",
        }
    }

    /// ANSI SGR foreground-color sequence used for this step's label.
    fn color_code(&self) -> &'static str {
        const RED: &str = "\x1b[31m";
        const GREEN: &str = "\x1b[32m";
        const YELLOW: &str = "\x1b[33m";
        const BLUE: &str = "\x1b[34m";
        const MAGENTA: &str = "\x1b[35m";
        const CYAN: &str = "\x1b[36m";

        match *self {
            Self::LoadProject => YELLOW,
            Self::Context => MAGENTA,
            Self::FormatPrompt => BLUE,
            Self::Predict => GREEN,
            Self::Score => RED,
            Self::Synthesize | Self::PullExamples => CYAN,
        }
    }
}
88
89static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
90static LOGGER: ProgressLogger = ProgressLogger;
91
92const MARGIN: usize = 4;
93const MAX_STATUS_LINES: usize = 10;
94const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
95
96impl Progress {
97 /// Returns the global Progress instance, initializing it if necessary.
98 pub fn global() -> Arc<Progress> {
99 GLOBAL
100 .get_or_init(|| {
101 let progress = Arc::new(Self {
102 inner: Mutex::new(ProgressInner {
103 completed: Vec::new(),
104 in_progress: HashMap::new(),
105 is_tty: std::env::var("COLOR").is_ok()
106 || (std::env::var("NO_COLOR").is_err()
107 && std::io::stderr().is_terminal()),
108 terminal_width: get_terminal_width(),
109 max_example_name_len: 0,
110 status_lines_displayed: 0,
111 total_examples: 0,
112 completed_examples: 0,
113 failed_examples: 0,
114 last_line_is_logging: false,
115 ticker: None,
116 }),
117 });
118 let _ = log::set_logger(&LOGGER);
119 log::set_max_level(log::LevelFilter::Info);
120 progress
121 })
122 .clone()
123 }
124
125 pub fn start_group(self: &Arc<Self>, example_name: &str) -> ExampleProgress {
126 ExampleProgress {
127 progress: self.clone(),
128 example_name: example_name.to_string(),
129 }
130 }
131
132 fn increment_completed(&self) {
133 let mut inner = self.inner.lock().unwrap();
134 inner.completed_examples += 1;
135 }
136
137 pub fn set_total_examples(&self, total: usize) {
138 let mut inner = self.inner.lock().unwrap();
139 inner.total_examples = total;
140 }
141
142 pub fn increment_failed(&self) {
143 let mut inner = self.inner.lock().unwrap();
144 inner.failed_examples += 1;
145 }
146
147 /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
148 /// This should be used for any output that needs to appear above the status lines.
149 fn log(&self, message: &str) {
150 let mut inner = self.inner.lock().unwrap();
151 Self::clear_status_lines(&mut inner);
152
153 if !inner.last_line_is_logging {
154 let reset = "\x1b[0m";
155 let dim = "\x1b[2m";
156 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
157 eprintln!("{dim}{divider}{reset}");
158 inner.last_line_is_logging = true;
159 }
160
161 let max_width = inner.terminal_width.saturating_sub(MARGIN);
162 for line in message.lines() {
163 let truncated = truncate_to_visible_width(line, max_width);
164 if truncated.len() < line.len() {
165 eprintln!("{}…", truncated);
166 } else {
167 eprintln!("{}", truncated);
168 }
169 }
170
171 Self::print_status_lines(&mut inner);
172 }
173
174 pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
175 let mut inner = self.inner.lock().unwrap();
176
177 Self::clear_status_lines(&mut inner);
178
179 let max_name_width = inner
180 .terminal_width
181 .saturating_sub(MARGIN * 2)
182 .saturating_div(3)
183 .max(1);
184 inner.max_example_name_len = inner
185 .max_example_name_len
186 .max(example_name.len().min(max_name_width));
187 inner.in_progress.insert(
188 example_name.to_string(),
189 InProgressTask {
190 step,
191 started_at: Instant::now(),
192 substatus: None,
193 info: None,
194 },
195 );
196
197 if inner.is_tty && inner.ticker.is_none() {
198 let progress = self.clone();
199 inner.ticker = Some(std::thread::spawn(move || {
200 loop {
201 std::thread::sleep(STATUS_TICK_INTERVAL);
202
203 let mut inner = progress.inner.lock().unwrap();
204 if inner.in_progress.is_empty() {
205 break;
206 }
207
208 Progress::clear_status_lines(&mut inner);
209 Progress::print_status_lines(&mut inner);
210 }
211 }));
212 }
213
214 Self::print_status_lines(&mut inner);
215
216 StepProgress {
217 progress: self.clone(),
218 step,
219 example_name: example_name.to_string(),
220 }
221 }
222
223 fn finish(&self, step: Step, example_name: &str) {
224 let mut inner = self.inner.lock().unwrap();
225
226 let Some(task) = inner.in_progress.remove(example_name) else {
227 return;
228 };
229
230 if task.step == step {
231 inner.completed.push(CompletedTask {
232 step: task.step,
233 example_name: example_name.to_string(),
234 duration: task.started_at.elapsed(),
235 info: task.info,
236 });
237
238 Self::clear_status_lines(&mut inner);
239 Self::print_logging_closing_divider(&mut inner);
240 if let Some(last_completed) = inner.completed.last() {
241 Self::print_completed(&inner, last_completed);
242 }
243 Self::print_status_lines(&mut inner);
244 } else {
245 inner.in_progress.insert(example_name.to_string(), task);
246 }
247 }
248
249 fn print_logging_closing_divider(inner: &mut ProgressInner) {
250 if inner.last_line_is_logging {
251 let reset = "\x1b[0m";
252 let dim = "\x1b[2m";
253 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
254 eprintln!("{dim}{divider}{reset}");
255 inner.last_line_is_logging = false;
256 }
257 }
258
259 fn clear_status_lines(inner: &mut ProgressInner) {
260 if inner.is_tty && inner.status_lines_displayed > 0 {
261 // Move up and clear each line we previously displayed
262 for _ in 0..inner.status_lines_displayed {
263 eprint!("\x1b[A\x1b[K");
264 }
265 inner.status_lines_displayed = 0;
266 }
267 }
268
269 fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
270 let duration = format_duration(task.duration);
271 let name_width = inner.max_example_name_len;
272 let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
273
274 if inner.is_tty {
275 let reset = "\x1b[0m";
276 let bold = "\x1b[1m";
277 let dim = "\x1b[2m";
278
279 let yellow = "\x1b[33m";
280 let info_part = task
281 .info
282 .as_ref()
283 .map(|(s, style)| {
284 if *style == InfoStyle::Warning {
285 format!("{yellow}{s}{reset}")
286 } else {
287 s.to_string()
288 }
289 })
290 .unwrap_or_default();
291
292 let prefix = format!(
293 "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
294 color = task.step.color_code(),
295 label = task.step.label(),
296 name = truncated_name,
297 );
298
299 let duration_with_margin = format!("{duration} ");
300 let padding_needed = inner
301 .terminal_width
302 .saturating_sub(MARGIN)
303 .saturating_sub(duration_with_margin.len())
304 .saturating_sub(strip_ansi_len(&prefix));
305 let padding = " ".repeat(padding_needed);
306
307 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
308 } else {
309 let info_part = task
310 .info
311 .as_ref()
312 .map(|(s, _)| format!(" | {}", s))
313 .unwrap_or_default();
314
315 eprintln!(
316 "{label:>12} {name:<name_width$}{info_part} {duration}",
317 label = task.step.label(),
318 name = truncate_with_ellipsis(&task.example_name, name_width),
319 );
320 }
321 }
322
323 fn print_status_lines(inner: &mut ProgressInner) {
324 if !inner.is_tty || inner.in_progress.is_empty() {
325 inner.status_lines_displayed = 0;
326 return;
327 }
328
329 let reset = "\x1b[0m";
330 let bold = "\x1b[1m";
331 let dim = "\x1b[2m";
332
333 // Build the done/in-progress/total label
334 let done_count = inner.completed_examples;
335 let in_progress_count = inner.in_progress.len();
336 let failed_count = inner.failed_examples;
337
338 let failed_label = if failed_count > 0 {
339 format!(" {} failed ", failed_count)
340 } else {
341 String::new()
342 };
343
344 let range_label = format!(
345 " {}/{}/{} ",
346 done_count, in_progress_count, inner.total_examples
347 );
348
349 // Print a divider line with failed count on left, range label on right
350 let failed_visible_len = strip_ansi_len(&failed_label);
351 let range_visible_len = range_label.len();
352 let middle_divider_len = inner
353 .terminal_width
354 .saturating_sub(MARGIN * 2)
355 .saturating_sub(failed_visible_len)
356 .saturating_sub(range_visible_len);
357 let left_divider = "─".repeat(MARGIN);
358 let middle_divider = "─".repeat(middle_divider_len);
359 let right_divider = "─".repeat(MARGIN);
360 eprintln!(
361 "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
362 );
363
364 let mut tasks: Vec<_> = inner.in_progress.iter().collect();
365 tasks.sort_by_key(|(name, _)| *name);
366
367 let total_tasks = tasks.len();
368 let mut lines_printed = 0;
369
370 for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
371 let elapsed = format_duration(task.started_at.elapsed());
372 let substatus_part = task
373 .substatus
374 .as_ref()
375 .map(|s| truncate_with_ellipsis(s, 30))
376 .unwrap_or_default();
377
378 let step_label = task.step.label();
379 let step_color = task.step.color_code();
380 let name_width = inner.max_example_name_len;
381 let truncated_name = truncate_with_ellipsis(name, name_width);
382
383 let prefix = format!(
384 "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
385 name = truncated_name,
386 );
387
388 let duration_with_margin = format!("{elapsed} ");
389 let padding_needed = inner
390 .terminal_width
391 .saturating_sub(MARGIN)
392 .saturating_sub(duration_with_margin.len())
393 .saturating_sub(strip_ansi_len(&prefix));
394 let padding = " ".repeat(padding_needed);
395
396 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
397 lines_printed += 1;
398 }
399
400 // Show "+N more" on its own line if there are more tasks
401 if total_tasks > MAX_STATUS_LINES {
402 let remaining = total_tasks - MAX_STATUS_LINES;
403 eprintln!("{:>12} +{remaining} more", "");
404 lines_printed += 1;
405 }
406
407 inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
408 let _ = std::io::stderr().flush();
409 }
410
411 pub fn finalize(&self) {
412 let ticker = {
413 let mut inner = self.inner.lock().unwrap();
414 inner.ticker.take()
415 };
416
417 if let Some(ticker) = ticker {
418 let _ = ticker.join();
419 }
420
421 let mut inner = self.inner.lock().unwrap();
422 Self::clear_status_lines(&mut inner);
423
424 // Print summary if there were failures
425 if inner.failed_examples > 0 {
426 let total_examples = inner.total_examples;
427 let percentage = if total_examples > 0 {
428 inner.failed_examples as f64 / total_examples as f64 * 100.0
429 } else {
430 0.0
431 };
432 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
433 eprintln!(
434 "\n{} of {} examples failed ({:.1}%)\nFailed examples: {}",
435 inner.failed_examples,
436 total_examples,
437 percentage,
438 failed_jsonl_path.display()
439 );
440 }
441 }
442}
443
444pub struct ExampleProgress {
445 progress: Arc<Progress>,
446 example_name: String,
447}
448
449impl ExampleProgress {
450 pub fn start(&self, step: Step) -> StepProgress {
451 self.progress.start(step, &self.example_name)
452 }
453}
454
455impl Drop for ExampleProgress {
456 fn drop(&mut self) {
457 self.progress.increment_completed();
458 }
459}
460
461pub struct StepProgress {
462 progress: Arc<Progress>,
463 step: Step,
464 example_name: String,
465}
466
467impl StepProgress {
468 pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
469 let mut inner = self.progress.inner.lock().unwrap();
470 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
471 task.substatus = Some(substatus.into().into_owned());
472 Progress::clear_status_lines(&mut inner);
473 Progress::print_status_lines(&mut inner);
474 }
475 }
476
477 pub fn clear_substatus(&self) {
478 let mut inner = self.progress.inner.lock().unwrap();
479 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
480 task.substatus = None;
481 Progress::clear_status_lines(&mut inner);
482 Progress::print_status_lines(&mut inner);
483 }
484 }
485
486 pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
487 let mut inner = self.progress.inner.lock().unwrap();
488 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
489 task.info = Some((info.into(), style));
490 }
491 }
492}
493
494impl Drop for StepProgress {
495 fn drop(&mut self) {
496 self.progress.finish(self.step, &self.example_name);
497 }
498}
499
500struct ProgressLogger;
501
502impl Log for ProgressLogger {
503 fn enabled(&self, metadata: &Metadata) -> bool {
504 metadata.level() <= Level::Info
505 }
506
507 fn log(&self, record: &Record) {
508 if !self.enabled(record.metadata()) {
509 return;
510 }
511
512 let level_color = match record.level() {
513 Level::Error => "\x1b[31m",
514 Level::Warn => "\x1b[33m",
515 Level::Info => "\x1b[32m",
516 Level::Debug => "\x1b[34m",
517 Level::Trace => "\x1b[35m",
518 };
519 let reset = "\x1b[0m";
520 let bold = "\x1b[1m";
521
522 let level_label = match record.level() {
523 Level::Error => "Error",
524 Level::Warn => "Warn",
525 Level::Info => "Info",
526 Level::Debug => "Debug",
527 Level::Trace => "Trace",
528 };
529
530 let message = format!(
531 "{bold}{level_color}{level_label:>12}{reset} {}",
532 record.args()
533 );
534
535 if let Some(progress) = GLOBAL.get() {
536 progress.log(&message);
537 } else {
538 eprintln!("{}", message);
539 }
540 }
541
542 fn flush(&self) {
543 let _ = std::io::stderr().flush();
544 }
545}
546
547#[cfg(unix)]
548fn get_terminal_width() -> usize {
549 unsafe {
550 let mut winsize: libc::winsize = std::mem::zeroed();
551 if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
552 && winsize.ws_col > 0
553 {
554 winsize.ws_col as usize
555 } else {
556 80
557 }
558 }
559}
560
/// Width of the terminal in columns.
///
/// No size query is implemented on this platform, so this always assumes 80 columns.
#[cfg(not(unix))]
fn get_terminal_width() -> usize {
    const DEFAULT_COLUMNS: usize = 80;
    DEFAULT_COLUMNS
}
565
/// Counts the visible characters of `s`, treating ANSI escape sequences as zero-width.
///
/// An escape sequence runs from `ESC` through the next `m`. The count is in `char`s, not
/// terminal cells, so wide glyphs count as one.
fn strip_ansi_len(s: &str) -> usize {
    let (visible, _) = s
        .chars()
        .fold((0usize, false), |(visible, escaping), ch| match (escaping, ch) {
            (_, '\x1b') => (visible, true),
            (true, 'm') => (visible, false),
            (true, _) => (visible, true),
            (false, _) => (visible + 1, false),
        });
    visible
}
582
/// Shortens `s` to at most `max_len` characters, marking the cut with a trailing `…`.
///
/// Length is measured in `char`s, matching how `{:<width$}` padding counts. Cuts always land
/// on char boundaries, so multi-byte UTF-8 text never causes a slicing panic. When truncation
/// happens, the result has `max_len - 1` characters plus the ellipsis. If `max_len` is 0, a
/// non-empty input becomes just `…`.
///
/// Returns the input borrowed when no truncation is needed.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> Cow<'_, str> {
    // `nth(max_len)` exists exactly when `s` has more than `max_len` chars.
    if s.char_indices().nth(max_len).is_none() {
        return Cow::Borrowed(s);
    }
    let keep = max_len.saturating_sub(1);
    let cut = s.char_indices().nth(keep).map_or(s.len(), |(index, _)| index);
    Cow::Owned(format!("{}…", &s[..cut]))
}
590
/// Returns the longest prefix of `s` that contains at most `max_visible_len` visible
/// characters.
///
/// ANSI escape sequences (`ESC` through the next `m`) are zero-width and kept, but only when
/// they come before the cut point. The prefix always ends on a char boundary.
fn truncate_to_visible_width(s: &str, max_visible_len: usize) -> &str {
    let mut shown = 0;
    let mut escaping = false;
    for (index, ch) in s.char_indices() {
        match (escaping, ch) {
            (_, '\x1b') => escaping = true,
            (true, 'm') => escaping = false,
            (true, _) => {}
            (false, _) => {
                if shown >= max_visible_len {
                    // `index` is the start of the first visible char that does not fit.
                    return &s[..index];
                }
                shown += 1;
            }
        }
    }
    s
}
612
/// Formats `duration` compactly:
/// - whole milliseconds below one second (`"250ms"`);
/// - one-decimal seconds below one minute (`"1.5s"`);
/// - one-decimal minutes otherwise (`"1.5m"`).
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60. * 1000.;

    match duration.as_millis() as f32 {
        ms if ms < MILLIS_PER_SECOND => format!("{ms}ms"),
        ms if ms < MILLIS_PER_MINUTE => format!("{:.1}s", ms / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms / MILLIS_PER_MINUTE),
    }
}