1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 io::{IsTerminal, Write},
5 sync::{Arc, Mutex, OnceLock},
6 time::{Duration, Instant},
7};
8
9use log::{Level, Log, Metadata, Record};
10
/// Live progress display written to stderr: one completion line per finished step, plus a block
/// of status lines for the steps still running.
///
/// All state lives behind a mutex; it is also accessed by a background ticker thread that
/// redraws the status lines.
pub struct Progress {
    inner: Mutex<ProgressInner>,
}
14
/// Mutable state guarded by `Progress`'s mutex.
struct ProgressInner {
    /// Every step finished so far, in completion order; its length is the "done" count.
    completed: Vec<CompletedTask>,
    /// Running steps keyed by example name; starting a new step for an example replaces the old.
    in_progress: HashMap<String, InProgressTask>,
    /// Whether to emit ANSI colors and redraw status lines in place. Initialized to true only
    /// when stderr is a terminal and `NO_COLOR` is unset.
    is_tty: bool,
    /// Terminal width in columns, read once when the global instance is created.
    terminal_width: usize,
    /// Width of the example-name column: the longest name seen so far, capped in `start`.
    max_example_name_len: usize,
    /// How many status lines are currently drawn at the bottom and must be erased before the
    /// next output.
    status_lines_displayed: usize,
    /// Total example count shown in the status divider.
    total_examples: usize,
    /// Number of failed examples, shown in the status divider and in the final summary.
    failed_examples: usize,
    /// Whether the most recent output was a log message; controls the dividers around log runs.
    last_line_is_logging: bool,
    /// Handle of the background thread that periodically redraws the status lines.
    ticker: Option<std::thread::JoinHandle<()>>,
}
27
/// A step currently running for one example.
#[derive(Clone)]
struct InProgressTask {
    /// Which step is running.
    step: Step,
    /// When the step started; used for the elapsed time in the status line and the final
    /// duration.
    started_at: Instant,
    /// Optional short status text shown (truncated) in this task's status line.
    substatus: Option<String>,
    /// Optional text and style carried over to the completion line when the step finishes.
    info: Option<(String, InfoStyle)>,
}
35
/// A finished step; counted in the status divider and printed once as a completion line.
struct CompletedTask {
    step: Step,
    example_name: String,
    /// Time elapsed between the step's start and its completion.
    duration: Duration,
    /// Optional text and style shown on the completion line.
    info: Option<(String, InfoStyle)>,
}
42
/// A kind of step whose progress can be tracked per example.
///
/// Each step has a display label (`Step::label`) and a terminal color.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
    LoadProject,
    Context,
    FormatPrompt,
    Predict,
    Score,
    Synthesize,
    PullExamples,
}
53
/// How the info text of a completed step is rendered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed without extra styling.
    Normal,
    /// Highlighted in yellow when writing to a terminal.
    Warning,
}
59
60impl Step {
61 pub fn label(&self) -> &'static str {
62 match self {
63 Step::LoadProject => "Load",
64 Step::Context => "Context",
65 Step::FormatPrompt => "Format",
66 Step::Predict => "Predict",
67 Step::Score => "Score",
68 Step::Synthesize => "Synthesize",
69 Step::PullExamples => "Pull",
70 }
71 }
72
73 fn color_code(&self) -> &'static str {
74 match self {
75 Step::LoadProject => "\x1b[33m",
76 Step::Context => "\x1b[35m",
77 Step::FormatPrompt => "\x1b[34m",
78 Step::Predict => "\x1b[32m",
79 Step::Score => "\x1b[31m",
80 Step::Synthesize => "\x1b[36m",
81 Step::PullExamples => "\x1b[36m",
82 }
83 }
84}
85
/// Shared instance created lazily by `Progress::global`.
static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
/// `log` backend installed by `Progress::global`.
static LOGGER: ProgressLogger = ProgressLogger;

/// Columns left free when laying lines out against the terminal width (the status divider also
/// uses it as the length of its left and right edges).
const MARGIN: usize = 4;
/// Maximum number of in-progress tasks listed under the status divider; the rest are summarized
/// on a single "+N more" line.
const MAX_STATUS_LINES: usize = 10;
/// How often the ticker thread redraws the status lines.
const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
92
93impl Progress {
94 /// Returns the global Progress instance, initializing it if necessary.
95 pub fn global() -> Arc<Progress> {
96 GLOBAL
97 .get_or_init(|| {
98 let progress = Arc::new(Self {
99 inner: Mutex::new(ProgressInner {
100 completed: Vec::new(),
101 in_progress: HashMap::new(),
102 is_tty: std::env::var("NO_COLOR").is_err()
103 && std::io::stderr().is_terminal(),
104 terminal_width: get_terminal_width(),
105 max_example_name_len: 0,
106 status_lines_displayed: 0,
107 total_examples: 0,
108 failed_examples: 0,
109 last_line_is_logging: false,
110 ticker: None,
111 }),
112 });
113 let _ = log::set_logger(&LOGGER);
114 log::set_max_level(log::LevelFilter::Info);
115 progress
116 })
117 .clone()
118 }
119
120 pub fn set_total_examples(&self, total: usize) {
121 let mut inner = self.inner.lock().unwrap();
122 inner.total_examples = total;
123 }
124
125 pub fn increment_failed(&self) {
126 let mut inner = self.inner.lock().unwrap();
127 inner.failed_examples += 1;
128 }
129
130 /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
131 /// This should be used for any output that needs to appear above the status lines.
132 fn log(&self, message: &str) {
133 let mut inner = self.inner.lock().unwrap();
134 Self::clear_status_lines(&mut inner);
135
136 if !inner.last_line_is_logging {
137 let reset = "\x1b[0m";
138 let dim = "\x1b[2m";
139 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
140 eprintln!("{dim}{divider}{reset}");
141 inner.last_line_is_logging = true;
142 }
143
144 let max_width = inner.terminal_width.saturating_sub(MARGIN);
145 for line in message.lines() {
146 let truncated = truncate_to_visible_width(line, max_width);
147 if truncated.len() < line.len() {
148 eprintln!("{}…", truncated);
149 } else {
150 eprintln!("{}", truncated);
151 }
152 }
153
154 Self::print_status_lines(&mut inner);
155 }
156
157 pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
158 let mut inner = self.inner.lock().unwrap();
159
160 Self::clear_status_lines(&mut inner);
161
162 let max_name_width = inner
163 .terminal_width
164 .saturating_sub(MARGIN * 2)
165 .saturating_div(3)
166 .max(1);
167 inner.max_example_name_len = inner
168 .max_example_name_len
169 .max(example_name.len().min(max_name_width));
170 inner.in_progress.insert(
171 example_name.to_string(),
172 InProgressTask {
173 step,
174 started_at: Instant::now(),
175 substatus: None,
176 info: None,
177 },
178 );
179
180 if inner.is_tty && inner.ticker.is_none() {
181 let progress = self.clone();
182 inner.ticker = Some(std::thread::spawn(move || {
183 loop {
184 std::thread::sleep(STATUS_TICK_INTERVAL);
185
186 let mut inner = progress.inner.lock().unwrap();
187 if inner.in_progress.is_empty() {
188 break;
189 }
190
191 Progress::clear_status_lines(&mut inner);
192 Progress::print_status_lines(&mut inner);
193 }
194 }));
195 }
196
197 Self::print_status_lines(&mut inner);
198
199 StepProgress {
200 progress: self.clone(),
201 step,
202 example_name: example_name.to_string(),
203 }
204 }
205
206 fn finish(&self, step: Step, example_name: &str) {
207 let mut inner = self.inner.lock().unwrap();
208
209 let Some(task) = inner.in_progress.remove(example_name) else {
210 return;
211 };
212
213 if task.step == step {
214 inner.completed.push(CompletedTask {
215 step: task.step,
216 example_name: example_name.to_string(),
217 duration: task.started_at.elapsed(),
218 info: task.info,
219 });
220
221 Self::clear_status_lines(&mut inner);
222 Self::print_logging_closing_divider(&mut inner);
223 if let Some(last_completed) = inner.completed.last() {
224 Self::print_completed(&inner, last_completed);
225 }
226 Self::print_status_lines(&mut inner);
227 } else {
228 inner.in_progress.insert(example_name.to_string(), task);
229 }
230 }
231
232 fn print_logging_closing_divider(inner: &mut ProgressInner) {
233 if inner.last_line_is_logging {
234 let reset = "\x1b[0m";
235 let dim = "\x1b[2m";
236 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
237 eprintln!("{dim}{divider}{reset}");
238 inner.last_line_is_logging = false;
239 }
240 }
241
242 fn clear_status_lines(inner: &mut ProgressInner) {
243 if inner.is_tty && inner.status_lines_displayed > 0 {
244 // Move up and clear each line we previously displayed
245 for _ in 0..inner.status_lines_displayed {
246 eprint!("\x1b[A\x1b[K");
247 }
248 let _ = std::io::stderr().flush();
249 inner.status_lines_displayed = 0;
250 }
251 }
252
253 fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
254 let duration = format_duration(task.duration);
255 let name_width = inner.max_example_name_len;
256 let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
257
258 if inner.is_tty {
259 let reset = "\x1b[0m";
260 let bold = "\x1b[1m";
261 let dim = "\x1b[2m";
262
263 let yellow = "\x1b[33m";
264 let info_part = task
265 .info
266 .as_ref()
267 .map(|(s, style)| {
268 if *style == InfoStyle::Warning {
269 format!("{yellow}{s}{reset}")
270 } else {
271 s.to_string()
272 }
273 })
274 .unwrap_or_default();
275
276 let prefix = format!(
277 "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
278 color = task.step.color_code(),
279 label = task.step.label(),
280 name = truncated_name,
281 );
282
283 let duration_with_margin = format!("{duration} ");
284 let padding_needed = inner
285 .terminal_width
286 .saturating_sub(MARGIN)
287 .saturating_sub(duration_with_margin.len())
288 .saturating_sub(strip_ansi_len(&prefix));
289 let padding = " ".repeat(padding_needed);
290
291 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
292 } else {
293 let info_part = task
294 .info
295 .as_ref()
296 .map(|(s, _)| format!(" | {}", s))
297 .unwrap_or_default();
298
299 eprintln!(
300 "{label:>12} {name:<name_width$}{info_part} {duration}",
301 label = task.step.label(),
302 name = truncate_with_ellipsis(&task.example_name, name_width),
303 );
304 }
305 }
306
307 fn print_status_lines(inner: &mut ProgressInner) {
308 if !inner.is_tty || inner.in_progress.is_empty() {
309 inner.status_lines_displayed = 0;
310 return;
311 }
312
313 let reset = "\x1b[0m";
314 let bold = "\x1b[1m";
315 let dim = "\x1b[2m";
316
317 // Build the done/in-progress/total label
318 let done_count = inner.completed.len();
319 let in_progress_count = inner.in_progress.len();
320 let failed_count = inner.failed_examples;
321
322 let failed_label = if failed_count > 0 {
323 format!(" {} failed ", failed_count)
324 } else {
325 String::new()
326 };
327
328 let range_label = format!(
329 " {}/{}/{} ",
330 done_count, in_progress_count, inner.total_examples
331 );
332
333 // Print a divider line with failed count on left, range label on right
334 let failed_visible_len = strip_ansi_len(&failed_label);
335 let range_visible_len = range_label.len();
336 let middle_divider_len = inner
337 .terminal_width
338 .saturating_sub(MARGIN * 2)
339 .saturating_sub(failed_visible_len)
340 .saturating_sub(range_visible_len);
341 let left_divider = "─".repeat(MARGIN);
342 let middle_divider = "─".repeat(middle_divider_len);
343 let right_divider = "─".repeat(MARGIN);
344 eprintln!(
345 "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
346 );
347
348 let mut tasks: Vec<_> = inner.in_progress.iter().collect();
349 tasks.sort_by_key(|(name, _)| *name);
350
351 let total_tasks = tasks.len();
352 let mut lines_printed = 0;
353
354 for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
355 let elapsed = format_duration(task.started_at.elapsed());
356 let substatus_part = task
357 .substatus
358 .as_ref()
359 .map(|s| truncate_with_ellipsis(s, 30))
360 .unwrap_or_default();
361
362 let step_label = task.step.label();
363 let step_color = task.step.color_code();
364 let name_width = inner.max_example_name_len;
365 let truncated_name = truncate_with_ellipsis(name, name_width);
366
367 let prefix = format!(
368 "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
369 name = truncated_name,
370 );
371
372 let duration_with_margin = format!("{elapsed} ");
373 let padding_needed = inner
374 .terminal_width
375 .saturating_sub(MARGIN)
376 .saturating_sub(duration_with_margin.len())
377 .saturating_sub(strip_ansi_len(&prefix));
378 let padding = " ".repeat(padding_needed);
379
380 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
381 lines_printed += 1;
382 }
383
384 // Show "+N more" on its own line if there are more tasks
385 if total_tasks > MAX_STATUS_LINES {
386 let remaining = total_tasks - MAX_STATUS_LINES;
387 eprintln!("{:>12} +{remaining} more", "");
388 lines_printed += 1;
389 }
390
391 inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
392 let _ = std::io::stderr().flush();
393 }
394
395 pub fn finalize(&self) {
396 let ticker = {
397 let mut inner = self.inner.lock().unwrap();
398 inner.ticker.take()
399 };
400
401 if let Some(ticker) = ticker {
402 let _ = ticker.join();
403 }
404
405 let mut inner = self.inner.lock().unwrap();
406 Self::clear_status_lines(&mut inner);
407
408 // Print summary if there were failures
409 if inner.failed_examples > 0 {
410 let total_examples = inner.total_examples;
411 let percentage = if total_examples > 0 {
412 inner.failed_examples as f64 / total_examples as f64 * 100.0
413 } else {
414 0.0
415 };
416 eprintln!(
417 "\n{} of {} examples failed ({:.1}%)",
418 inner.failed_examples, total_examples, percentage
419 );
420 }
421 }
422}
423
/// Guard for one running step of one example, returned by `Progress::start`.
///
/// Dropping it finishes the step (see the `Drop` impl). If the example is still running this
/// step, finishing records the duration and prints a completion line.
pub struct StepProgress {
    progress: Arc<Progress>,
    step: Step,
    example_name: String,
}
429
430impl StepProgress {
431 pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
432 let mut inner = self.progress.inner.lock().unwrap();
433 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
434 task.substatus = Some(substatus.into().into_owned());
435 Progress::clear_status_lines(&mut inner);
436 Progress::print_status_lines(&mut inner);
437 }
438 }
439
440 pub fn clear_substatus(&self) {
441 let mut inner = self.progress.inner.lock().unwrap();
442 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
443 task.substatus = None;
444 Progress::clear_status_lines(&mut inner);
445 Progress::print_status_lines(&mut inner);
446 }
447 }
448
449 pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
450 let mut inner = self.progress.inner.lock().unwrap();
451 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
452 task.info = Some((info.into(), style));
453 }
454 }
455}
456
457impl Drop for StepProgress {
458 fn drop(&mut self) {
459 self.progress.finish(self.step, &self.example_name);
460 }
461}
462
/// `log` backend that formats records with a colored level label. Records are printed through
/// the global `Progress` when it exists, and directly to stderr otherwise.
struct ProgressLogger;
464
465impl Log for ProgressLogger {
466 fn enabled(&self, metadata: &Metadata) -> bool {
467 metadata.level() <= Level::Info
468 }
469
470 fn log(&self, record: &Record) {
471 if !self.enabled(record.metadata()) {
472 return;
473 }
474
475 let level_color = match record.level() {
476 Level::Error => "\x1b[31m",
477 Level::Warn => "\x1b[33m",
478 Level::Info => "\x1b[32m",
479 Level::Debug => "\x1b[34m",
480 Level::Trace => "\x1b[35m",
481 };
482 let reset = "\x1b[0m";
483 let bold = "\x1b[1m";
484
485 let level_label = match record.level() {
486 Level::Error => "Error",
487 Level::Warn => "Warn",
488 Level::Info => "Info",
489 Level::Debug => "Debug",
490 Level::Trace => "Trace",
491 };
492
493 let message = format!(
494 "{bold}{level_color}{level_label:>12}{reset} {}",
495 record.args()
496 );
497
498 if let Some(progress) = GLOBAL.get() {
499 progress.log(&message);
500 } else {
501 eprintln!("{}", message);
502 }
503 }
504
505 fn flush(&self) {
506 let _ = std::io::stderr().flush();
507 }
508}
509
510#[cfg(unix)]
511fn get_terminal_width() -> usize {
512 unsafe {
513 let mut winsize: libc::winsize = std::mem::zeroed();
514 if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
515 && winsize.ws_col > 0
516 {
517 winsize.ws_col as usize
518 } else {
519 80
520 }
521 }
522}
523
#[cfg(not(unix))]
/// Width detection is only implemented for Unix; elsewhere assume an 80-column terminal.
fn get_terminal_width() -> usize {
    const DEFAULT_COLUMNS: usize = 80;
    DEFAULT_COLUMNS
}
528
/// Counts the visible `char`s in `s`, skipping `ESC … m` escape sequences.
///
/// Everything from an `ESC` up to and including the next `m` is treated as invisible.
fn strip_ansi_len(s: &str) -> usize {
    let (visible, _) = s
        .chars()
        .fold((0usize, false), |(count, escaping), ch| match (escaping, ch) {
            (_, '\x1b') => (count, true),
            (true, 'm') => (count, false),
            (true, _) => (count, true),
            (false, _) => (count + 1, false),
        });
    visible
}
545
/// Shortens `s` to at most `max_len` characters, replacing the tail with `…` when it doesn't fit.
///
/// Lengths are counted in `char`s, which matches how `format!` width specifiers pad. The cut
/// therefore always lands on a UTF-8 boundary, so multi-byte names cannot cause a slicing panic.
/// Returns the input unchanged (borrowed) when it already fits. With `max_len == 0`, a non-empty
/// input becomes just `…`.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> Cow<'_, str> {
    // A char at index `max_len` exists only if `s` is longer than `max_len` chars.
    if s.char_indices().nth(max_len).is_none() {
        return Cow::Borrowed(s);
    }

    // Keep `max_len - 1` chars to leave room for the ellipsis.
    let keep = max_len.saturating_sub(1);
    let cut = s
        .char_indices()
        .nth(keep)
        .map_or(s.len(), |(byte_index, _)| byte_index);
    Cow::Owned(format!("{}…", &s[..cut]))
}
553
/// Returns the longest prefix of `s` that contains at most `max_visible_len` visible `char`s.
///
/// `ESC … m` escape sequences count as zero width. Escapes that follow the last kept visible
/// char are kept only when no further visible char follows them.
fn truncate_to_visible_width(s: &str, max_visible_len: usize) -> &str {
    let mut shown = 0;
    let mut escaping = false;

    for (idx, ch) in s.char_indices() {
        match ch {
            '\x1b' => escaping = true,
            'm' if escaping => escaping = false,
            _ if escaping => {}
            _ => {
                if shown >= max_visible_len {
                    // `idx` is where the first visible char that no longer fits begins.
                    return &s[..idx];
                }
                shown += 1;
            }
        }
    }

    s
}
575
/// Formats a duration compactly: whole milliseconds below one second, then seconds, then minutes
/// (both with one decimal place).
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    let ms = duration.as_millis() as f32;
    match ms {
        ms if ms < MILLIS_PER_SECOND => format!("{ms}ms"),
        ms if ms < MILLIS_PER_MINUTE => format!("{:.1}s", ms / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms / MILLIS_PER_MINUTE),
    }
}