1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 io::{IsTerminal, Write},
5 sync::{Arc, Mutex, OnceLock},
6 time::{Duration, Instant},
7};
8
9use crate::paths::RUN_DIR;
10
11use log::{Level, Log, Metadata, Record};
12
13pub struct Progress {
14 inner: Mutex<ProgressInner>,
15}
16
17struct ProgressInner {
18 completed: Vec<CompletedTask>,
19 in_progress: HashMap<String, InProgressTask>,
20 is_tty: bool,
21 terminal_width: usize,
22 max_example_name_len: usize,
23 status_lines_displayed: usize,
24 total_examples: usize,
25 failed_examples: usize,
26 last_line_is_logging: bool,
27 ticker: Option<std::thread::JoinHandle<()>>,
28}
29
30#[derive(Clone)]
31struct InProgressTask {
32 step: Step,
33 started_at: Instant,
34 substatus: Option<String>,
35 info: Option<(String, InfoStyle)>,
36}
37
38struct CompletedTask {
39 step: Step,
40 example_name: String,
41 duration: Duration,
42 info: Option<(String, InfoStyle)>,
43}
44
/// Pipeline stage an example can be in; selects the label and color of its output lines.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
    LoadProject,
    Context,
    FormatPrompt,
    Predict,
    Score,
    Synthesize,
    PullExamples,
}

/// How the `info` text of a completed task is rendered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed as-is.
    Normal,
    /// Printed in yellow when styled output is enabled.
    Warning,
}

impl Step {
    /// Short human-readable name printed in the right-aligned label column.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::LoadProject => "Load",
            Self::Context => "Context",
            Self::FormatPrompt => "Format",
            Self::Predict => "Predict",
            Self::Score => "Score",
            Self::Synthesize => "Synthesize",
            Self::PullExamples => "Pull",
        }
    }

    /// ANSI SGR foreground-color sequence used for this step's label.
    fn color_code(&self) -> &'static str {
        const RED: &str = "\x1b[31m";
        const GREEN: &str = "\x1b[32m";
        const YELLOW: &str = "\x1b[33m";
        const BLUE: &str = "\x1b[34m";
        const MAGENTA: &str = "\x1b[35m";
        const CYAN: &str = "\x1b[36m";

        match *self {
            Self::LoadProject => YELLOW,
            Self::Context => MAGENTA,
            Self::FormatPrompt => BLUE,
            Self::Predict => GREEN,
            Self::Score => RED,
            Self::Synthesize | Self::PullExamples => CYAN,
        }
    }
}
87
88static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
89static LOGGER: ProgressLogger = ProgressLogger;
90
/// Columns kept free at the line edges when laying out output.
const MARGIN: usize = 4;
/// Maximum number of per-task status lines; the rest are summarized as "+N more".
const MAX_STATUS_LINES: usize = 10;
/// Interval between background redraws of the status lines.
const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
94
95impl Progress {
96 /// Returns the global Progress instance, initializing it if necessary.
97 pub fn global() -> Arc<Progress> {
98 GLOBAL
99 .get_or_init(|| {
100 let progress = Arc::new(Self {
101 inner: Mutex::new(ProgressInner {
102 completed: Vec::new(),
103 in_progress: HashMap::new(),
104 is_tty: std::env::var("NO_COLOR").is_err()
105 && std::io::stderr().is_terminal(),
106 terminal_width: get_terminal_width(),
107 max_example_name_len: 0,
108 status_lines_displayed: 0,
109 total_examples: 0,
110 failed_examples: 0,
111 last_line_is_logging: false,
112 ticker: None,
113 }),
114 });
115 let _ = log::set_logger(&LOGGER);
116 log::set_max_level(log::LevelFilter::Info);
117 progress
118 })
119 .clone()
120 }
121
122 pub fn set_total_examples(&self, total: usize) {
123 let mut inner = self.inner.lock().unwrap();
124 inner.total_examples = total;
125 }
126
127 pub fn increment_failed(&self) {
128 let mut inner = self.inner.lock().unwrap();
129 inner.failed_examples += 1;
130 }
131
132 /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
133 /// This should be used for any output that needs to appear above the status lines.
134 fn log(&self, message: &str) {
135 let mut inner = self.inner.lock().unwrap();
136 Self::clear_status_lines(&mut inner);
137
138 if !inner.last_line_is_logging {
139 let reset = "\x1b[0m";
140 let dim = "\x1b[2m";
141 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
142 eprintln!("{dim}{divider}{reset}");
143 inner.last_line_is_logging = true;
144 }
145
146 let max_width = inner.terminal_width.saturating_sub(MARGIN);
147 for line in message.lines() {
148 let truncated = truncate_to_visible_width(line, max_width);
149 if truncated.len() < line.len() {
150 eprintln!("{}…", truncated);
151 } else {
152 eprintln!("{}", truncated);
153 }
154 }
155
156 Self::print_status_lines(&mut inner);
157 }
158
159 pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
160 let mut inner = self.inner.lock().unwrap();
161
162 Self::clear_status_lines(&mut inner);
163
164 let max_name_width = inner
165 .terminal_width
166 .saturating_sub(MARGIN * 2)
167 .saturating_div(3)
168 .max(1);
169 inner.max_example_name_len = inner
170 .max_example_name_len
171 .max(example_name.len().min(max_name_width));
172 inner.in_progress.insert(
173 example_name.to_string(),
174 InProgressTask {
175 step,
176 started_at: Instant::now(),
177 substatus: None,
178 info: None,
179 },
180 );
181
182 if inner.is_tty && inner.ticker.is_none() {
183 let progress = self.clone();
184 inner.ticker = Some(std::thread::spawn(move || {
185 loop {
186 std::thread::sleep(STATUS_TICK_INTERVAL);
187
188 let mut inner = progress.inner.lock().unwrap();
189 if inner.in_progress.is_empty() {
190 break;
191 }
192
193 Progress::clear_status_lines(&mut inner);
194 Progress::print_status_lines(&mut inner);
195 }
196 }));
197 }
198
199 Self::print_status_lines(&mut inner);
200
201 StepProgress {
202 progress: self.clone(),
203 step,
204 example_name: example_name.to_string(),
205 }
206 }
207
208 fn finish(&self, step: Step, example_name: &str) {
209 let mut inner = self.inner.lock().unwrap();
210
211 let Some(task) = inner.in_progress.remove(example_name) else {
212 return;
213 };
214
215 if task.step == step {
216 inner.completed.push(CompletedTask {
217 step: task.step,
218 example_name: example_name.to_string(),
219 duration: task.started_at.elapsed(),
220 info: task.info,
221 });
222
223 Self::clear_status_lines(&mut inner);
224 Self::print_logging_closing_divider(&mut inner);
225 if let Some(last_completed) = inner.completed.last() {
226 Self::print_completed(&inner, last_completed);
227 }
228 Self::print_status_lines(&mut inner);
229 } else {
230 inner.in_progress.insert(example_name.to_string(), task);
231 }
232 }
233
234 fn print_logging_closing_divider(inner: &mut ProgressInner) {
235 if inner.last_line_is_logging {
236 let reset = "\x1b[0m";
237 let dim = "\x1b[2m";
238 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
239 eprintln!("{dim}{divider}{reset}");
240 inner.last_line_is_logging = false;
241 }
242 }
243
244 fn clear_status_lines(inner: &mut ProgressInner) {
245 if inner.is_tty && inner.status_lines_displayed > 0 {
246 // Move up and clear each line we previously displayed
247 for _ in 0..inner.status_lines_displayed {
248 eprint!("\x1b[A\x1b[K");
249 }
250 let _ = std::io::stderr().flush();
251 inner.status_lines_displayed = 0;
252 }
253 }
254
255 fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
256 let duration = format_duration(task.duration);
257 let name_width = inner.max_example_name_len;
258 let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
259
260 if inner.is_tty {
261 let reset = "\x1b[0m";
262 let bold = "\x1b[1m";
263 let dim = "\x1b[2m";
264
265 let yellow = "\x1b[33m";
266 let info_part = task
267 .info
268 .as_ref()
269 .map(|(s, style)| {
270 if *style == InfoStyle::Warning {
271 format!("{yellow}{s}{reset}")
272 } else {
273 s.to_string()
274 }
275 })
276 .unwrap_or_default();
277
278 let prefix = format!(
279 "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
280 color = task.step.color_code(),
281 label = task.step.label(),
282 name = truncated_name,
283 );
284
285 let duration_with_margin = format!("{duration} ");
286 let padding_needed = inner
287 .terminal_width
288 .saturating_sub(MARGIN)
289 .saturating_sub(duration_with_margin.len())
290 .saturating_sub(strip_ansi_len(&prefix));
291 let padding = " ".repeat(padding_needed);
292
293 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
294 } else {
295 let info_part = task
296 .info
297 .as_ref()
298 .map(|(s, _)| format!(" | {}", s))
299 .unwrap_or_default();
300
301 eprintln!(
302 "{label:>12} {name:<name_width$}{info_part} {duration}",
303 label = task.step.label(),
304 name = truncate_with_ellipsis(&task.example_name, name_width),
305 );
306 }
307 }
308
309 fn print_status_lines(inner: &mut ProgressInner) {
310 if !inner.is_tty || inner.in_progress.is_empty() {
311 inner.status_lines_displayed = 0;
312 return;
313 }
314
315 let reset = "\x1b[0m";
316 let bold = "\x1b[1m";
317 let dim = "\x1b[2m";
318
319 // Build the done/in-progress/total label
320 let done_count = inner.completed.len();
321 let in_progress_count = inner.in_progress.len();
322 let failed_count = inner.failed_examples;
323
324 let failed_label = if failed_count > 0 {
325 format!(" {} failed ", failed_count)
326 } else {
327 String::new()
328 };
329
330 let range_label = format!(
331 " {}/{}/{} ",
332 done_count, in_progress_count, inner.total_examples
333 );
334
335 // Print a divider line with failed count on left, range label on right
336 let failed_visible_len = strip_ansi_len(&failed_label);
337 let range_visible_len = range_label.len();
338 let middle_divider_len = inner
339 .terminal_width
340 .saturating_sub(MARGIN * 2)
341 .saturating_sub(failed_visible_len)
342 .saturating_sub(range_visible_len);
343 let left_divider = "─".repeat(MARGIN);
344 let middle_divider = "─".repeat(middle_divider_len);
345 let right_divider = "─".repeat(MARGIN);
346 eprintln!(
347 "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
348 );
349
350 let mut tasks: Vec<_> = inner.in_progress.iter().collect();
351 tasks.sort_by_key(|(name, _)| *name);
352
353 let total_tasks = tasks.len();
354 let mut lines_printed = 0;
355
356 for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
357 let elapsed = format_duration(task.started_at.elapsed());
358 let substatus_part = task
359 .substatus
360 .as_ref()
361 .map(|s| truncate_with_ellipsis(s, 30))
362 .unwrap_or_default();
363
364 let step_label = task.step.label();
365 let step_color = task.step.color_code();
366 let name_width = inner.max_example_name_len;
367 let truncated_name = truncate_with_ellipsis(name, name_width);
368
369 let prefix = format!(
370 "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
371 name = truncated_name,
372 );
373
374 let duration_with_margin = format!("{elapsed} ");
375 let padding_needed = inner
376 .terminal_width
377 .saturating_sub(MARGIN)
378 .saturating_sub(duration_with_margin.len())
379 .saturating_sub(strip_ansi_len(&prefix));
380 let padding = " ".repeat(padding_needed);
381
382 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
383 lines_printed += 1;
384 }
385
386 // Show "+N more" on its own line if there are more tasks
387 if total_tasks > MAX_STATUS_LINES {
388 let remaining = total_tasks - MAX_STATUS_LINES;
389 eprintln!("{:>12} +{remaining} more", "");
390 lines_printed += 1;
391 }
392
393 inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
394 let _ = std::io::stderr().flush();
395 }
396
397 pub fn finalize(&self) {
398 let ticker = {
399 let mut inner = self.inner.lock().unwrap();
400 inner.ticker.take()
401 };
402
403 if let Some(ticker) = ticker {
404 let _ = ticker.join();
405 }
406
407 let mut inner = self.inner.lock().unwrap();
408 Self::clear_status_lines(&mut inner);
409
410 // Print summary if there were failures
411 if inner.failed_examples > 0 {
412 let total_examples = inner.total_examples;
413 let percentage = if total_examples > 0 {
414 inner.failed_examples as f64 / total_examples as f64 * 100.0
415 } else {
416 0.0
417 };
418 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
419 eprintln!(
420 "\n{} of {} examples failed ({:.1}%)\nFailed examples: {}",
421 inner.failed_examples,
422 total_examples,
423 percentage,
424 failed_jsonl_path.display()
425 );
426 }
427 }
428}
429
430pub struct StepProgress {
431 progress: Arc<Progress>,
432 step: Step,
433 example_name: String,
434}
435
436impl StepProgress {
437 pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
438 let mut inner = self.progress.inner.lock().unwrap();
439 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
440 task.substatus = Some(substatus.into().into_owned());
441 Progress::clear_status_lines(&mut inner);
442 Progress::print_status_lines(&mut inner);
443 }
444 }
445
446 pub fn clear_substatus(&self) {
447 let mut inner = self.progress.inner.lock().unwrap();
448 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
449 task.substatus = None;
450 Progress::clear_status_lines(&mut inner);
451 Progress::print_status_lines(&mut inner);
452 }
453 }
454
455 pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
456 let mut inner = self.progress.inner.lock().unwrap();
457 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
458 task.info = Some((info.into(), style));
459 }
460 }
461}
462
463impl Drop for StepProgress {
464 fn drop(&mut self) {
465 self.progress.finish(self.step, &self.example_name);
466 }
467}
468
469struct ProgressLogger;
470
471impl Log for ProgressLogger {
472 fn enabled(&self, metadata: &Metadata) -> bool {
473 metadata.level() <= Level::Info
474 }
475
476 fn log(&self, record: &Record) {
477 if !self.enabled(record.metadata()) {
478 return;
479 }
480
481 let level_color = match record.level() {
482 Level::Error => "\x1b[31m",
483 Level::Warn => "\x1b[33m",
484 Level::Info => "\x1b[32m",
485 Level::Debug => "\x1b[34m",
486 Level::Trace => "\x1b[35m",
487 };
488 let reset = "\x1b[0m";
489 let bold = "\x1b[1m";
490
491 let level_label = match record.level() {
492 Level::Error => "Error",
493 Level::Warn => "Warn",
494 Level::Info => "Info",
495 Level::Debug => "Debug",
496 Level::Trace => "Trace",
497 };
498
499 let message = format!(
500 "{bold}{level_color}{level_label:>12}{reset} {}",
501 record.args()
502 );
503
504 if let Some(progress) = GLOBAL.get() {
505 progress.log(&message);
506 } else {
507 eprintln!("{}", message);
508 }
509 }
510
511 fn flush(&self) {
512 let _ = std::io::stderr().flush();
513 }
514}
515
516#[cfg(unix)]
517fn get_terminal_width() -> usize {
518 unsafe {
519 let mut winsize: libc::winsize = std::mem::zeroed();
520 if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
521 && winsize.ws_col > 0
522 {
523 winsize.ws_col as usize
524 } else {
525 80
526 }
527 }
528}
529
/// On non-Unix targets the terminal size is not queried; an 80-column terminal is assumed.
#[cfg(not(unix))]
fn get_terminal_width() -> usize {
    const DEFAULT_TERMINAL_WIDTH: usize = 80;
    DEFAULT_TERMINAL_WIDTH
}
534
/// Counts the visible characters of `s`, ignoring ANSI escape sequences.
///
/// A sequence is treated as running from `ESC` (`\x1b`) up to and including the next `m`,
/// which covers the SGR style codes emitted in this file.
fn strip_ansi_len(s: &str) -> usize {
    s.chars()
        .scan(false, |escaping, ch| {
            let visible = if ch == '\x1b' {
                *escaping = true;
                false
            } else if *escaping {
                // Stay inside the sequence until its terminating `m`.
                *escaping = ch != 'm';
                false
            } else {
                true
            };
            Some(visible)
        })
        .filter(|&visible| visible)
        .count()
}
551
/// Shortens `s` to at most `max_len` characters, replacing the cut tail with `…`.
///
/// Lengths are measured in `char`s, matching how `{:<width$}` pads. Slicing by byte offset
/// would panic whenever the cut landed inside a multi-byte UTF-8 character. A truncated
/// result is the first `max_len - 1` characters plus `…`, i.e. exactly `max_len` characters,
/// or just `…` when `max_len` is 0. Strings that already fit are returned borrowed.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> Cow<'_, str> {
    if s.chars().count() <= max_len {
        return Cow::Borrowed(s);
    }
    let kept: String = s.chars().take(max_len.saturating_sub(1)).collect();
    Cow::Owned(format!("{kept}…"))
}
559
/// Returns the longest prefix of `s` that contains at most `max_visible_len` visible characters.
///
/// ANSI escape sequences (`ESC` … `m`) take no width. Escape sequences that appear before the
/// next visible character are kept, including ones right after the last kept character.
fn truncate_to_visible_width(s: &str, max_visible_len: usize) -> &str {
    let mut visible = 0;
    let mut escaping = false;
    for (start, ch) in s.char_indices() {
        match (ch, escaping) {
            ('\x1b', _) => escaping = true,
            ('m', true) => escaping = false,
            (_, true) => {}
            (_, false) => {
                if visible >= max_visible_len {
                    // Cut right before the first visible character that does not fit.
                    return &s[..start];
                }
                visible += 1;
            }
        }
    }
    s
}
581
/// Formats a duration compactly.
///
/// Output forms:
/// - whole milliseconds below one second (`250ms`);
/// - seconds with one decimal below one minute (`1.5s`);
/// - minutes with one decimal otherwise (`2.0m`).
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    match duration.as_millis() as f32 {
        ms if ms < MILLIS_PER_SECOND => format!("{ms}ms"),
        ms if ms < MILLIS_PER_MINUTE => format!("{:.1}s", ms / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms / MILLIS_PER_MINUTE),
    }
}