1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 io::{IsTerminal, Write},
5 sync::{Arc, Mutex, OnceLock},
6 time::{Duration, Instant},
7};
8
9use log::{Level, Log, Metadata, Record};
10
/// Thread-safe progress reporter for per-example pipeline steps.
///
/// All mutable state lives behind a single mutex. When stderr is an interactive
/// terminal and `NO_COLOR` is unset, a block of live status lines is kept at the
/// bottom of the output; otherwise only one line per completed step is printed.
pub struct Progress {
    inner: Mutex<ProgressInner>,
}
14
15struct ProgressInner {
16 completed: Vec<CompletedTask>,
17 in_progress: HashMap<String, InProgressTask>,
18 is_tty: bool,
19 terminal_width: usize,
20 max_example_name_len: usize,
21 status_lines_displayed: usize,
22 total_examples: usize,
23 failed_examples: usize,
24 last_line_is_logging: bool,
25 ticker: Option<std::thread::JoinHandle<()>>,
26}
27
28#[derive(Clone)]
29struct InProgressTask {
30 step: Step,
31 started_at: Instant,
32 substatus: Option<String>,
33 info: Option<(String, InfoStyle)>,
34}
35
36struct CompletedTask {
37 step: Step,
38 example_name: String,
39 duration: Duration,
40 info: Option<(String, InfoStyle)>,
41}
42
/// A pipeline stage reported by the progress display.
///
/// Each variant has a short display label (see `Step::label`) and an ANSI color.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
    /// Displayed as "Load".
    LoadProject,
    /// Displayed as "Context".
    Context,
    /// Displayed as "Format".
    FormatPrompt,
    /// Displayed as "Predict".
    Predict,
    /// Displayed as "Score".
    Score,
    /// Displayed as "Synthesize".
    Synthesize,
    /// Displayed as "Pull".
    PullExamples,
}
53
/// How a completed step's info text is rendered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed as-is.
    Normal,
    /// Printed in yellow when styled (TTY) output is enabled.
    Warning,
}
59
60impl Step {
61 pub fn label(&self) -> &'static str {
62 match self {
63 Step::LoadProject => "Load",
64 Step::Context => "Context",
65 Step::FormatPrompt => "Format",
66 Step::Predict => "Predict",
67 Step::Score => "Score",
68 Step::Synthesize => "Synthesize",
69 Step::PullExamples => "Pull",
70 }
71 }
72
73 fn color_code(&self) -> &'static str {
74 match self {
75 Step::LoadProject => "\x1b[33m",
76 Step::Context => "\x1b[35m",
77 Step::FormatPrompt => "\x1b[34m",
78 Step::Predict => "\x1b[32m",
79 Step::Score => "\x1b[31m",
80 Step::Synthesize => "\x1b[36m",
81 Step::PullExamples => "\x1b[36m",
82 }
83 }
84}
85
86static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
87static LOGGER: ProgressLogger = ProgressLogger;
88
/// Interval between background redraws of the live status lines.
const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
/// Maximum number of in-progress tasks listed individually; any extra are summarized as "+N more".
const MAX_STATUS_LINES: usize = 10;
/// Columns reserved as margin by dividers and by right-aligned duration columns.
const MARGIN: usize = 4;
92
93impl Progress {
94 /// Returns the global Progress instance, initializing it if necessary.
95 pub fn global() -> Arc<Progress> {
96 GLOBAL
97 .get_or_init(|| {
98 let progress = Arc::new(Self {
99 inner: Mutex::new(ProgressInner {
100 completed: Vec::new(),
101 in_progress: HashMap::new(),
102 is_tty: std::env::var("NO_COLOR").is_err()
103 && std::io::stderr().is_terminal(),
104 terminal_width: get_terminal_width(),
105 max_example_name_len: 0,
106 status_lines_displayed: 0,
107 total_examples: 0,
108 failed_examples: 0,
109 last_line_is_logging: false,
110 ticker: None,
111 }),
112 });
113 let _ = log::set_logger(&LOGGER);
114 log::set_max_level(log::LevelFilter::Info);
115 progress
116 })
117 .clone()
118 }
119
120 pub fn set_total_examples(&self, total: usize) {
121 let mut inner = self.inner.lock().unwrap();
122 inner.total_examples = total;
123 }
124
125 pub fn increment_failed(&self) {
126 let mut inner = self.inner.lock().unwrap();
127 inner.failed_examples += 1;
128 }
129
130 /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
131 /// This should be used for any output that needs to appear above the status lines.
132 fn log(&self, message: &str) {
133 let mut inner = self.inner.lock().unwrap();
134 Self::clear_status_lines(&mut inner);
135
136 if !inner.last_line_is_logging {
137 let reset = "\x1b[0m";
138 let dim = "\x1b[2m";
139 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
140 eprintln!("{dim}{divider}{reset}");
141 inner.last_line_is_logging = true;
142 }
143
144 eprintln!("{}", message);
145 }
146
147 pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
148 let mut inner = self.inner.lock().unwrap();
149
150 Self::clear_status_lines(&mut inner);
151
152 let max_name_width = inner
153 .terminal_width
154 .saturating_sub(MARGIN * 2)
155 .saturating_div(3)
156 .max(1);
157 inner.max_example_name_len = inner
158 .max_example_name_len
159 .max(example_name.len().min(max_name_width));
160 inner.in_progress.insert(
161 example_name.to_string(),
162 InProgressTask {
163 step,
164 started_at: Instant::now(),
165 substatus: None,
166 info: None,
167 },
168 );
169
170 if inner.is_tty && inner.ticker.is_none() {
171 let progress = self.clone();
172 inner.ticker = Some(std::thread::spawn(move || {
173 loop {
174 std::thread::sleep(STATUS_TICK_INTERVAL);
175
176 let mut inner = progress.inner.lock().unwrap();
177 if inner.in_progress.is_empty() {
178 break;
179 }
180
181 Progress::clear_status_lines(&mut inner);
182 Progress::print_status_lines(&mut inner);
183 }
184 }));
185 }
186
187 Self::print_status_lines(&mut inner);
188
189 StepProgress {
190 progress: self.clone(),
191 step,
192 example_name: example_name.to_string(),
193 }
194 }
195
196 fn finish(&self, step: Step, example_name: &str) {
197 let mut inner = self.inner.lock().unwrap();
198
199 let Some(task) = inner.in_progress.remove(example_name) else {
200 return;
201 };
202
203 if task.step == step {
204 inner.completed.push(CompletedTask {
205 step: task.step,
206 example_name: example_name.to_string(),
207 duration: task.started_at.elapsed(),
208 info: task.info,
209 });
210
211 Self::clear_status_lines(&mut inner);
212 Self::print_logging_closing_divider(&mut inner);
213 if let Some(last_completed) = inner.completed.last() {
214 Self::print_completed(&inner, last_completed);
215 }
216 Self::print_status_lines(&mut inner);
217 } else {
218 inner.in_progress.insert(example_name.to_string(), task);
219 }
220 }
221
222 fn print_logging_closing_divider(inner: &mut ProgressInner) {
223 if inner.last_line_is_logging {
224 let reset = "\x1b[0m";
225 let dim = "\x1b[2m";
226 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
227 eprintln!("{dim}{divider}{reset}");
228 inner.last_line_is_logging = false;
229 }
230 }
231
232 fn clear_status_lines(inner: &mut ProgressInner) {
233 if inner.is_tty && inner.status_lines_displayed > 0 {
234 // Move up and clear each line we previously displayed
235 for _ in 0..inner.status_lines_displayed {
236 eprint!("\x1b[A\x1b[K");
237 }
238 let _ = std::io::stderr().flush();
239 inner.status_lines_displayed = 0;
240 }
241 }
242
243 fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
244 let duration = format_duration(task.duration);
245 let name_width = inner.max_example_name_len;
246 let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
247
248 if inner.is_tty {
249 let reset = "\x1b[0m";
250 let bold = "\x1b[1m";
251 let dim = "\x1b[2m";
252
253 let yellow = "\x1b[33m";
254 let info_part = task
255 .info
256 .as_ref()
257 .map(|(s, style)| {
258 if *style == InfoStyle::Warning {
259 format!("{yellow}{s}{reset}")
260 } else {
261 s.to_string()
262 }
263 })
264 .unwrap_or_default();
265
266 let prefix = format!(
267 "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
268 color = task.step.color_code(),
269 label = task.step.label(),
270 name = truncated_name,
271 );
272
273 let duration_with_margin = format!("{duration} ");
274 let padding_needed = inner
275 .terminal_width
276 .saturating_sub(MARGIN)
277 .saturating_sub(duration_with_margin.len())
278 .saturating_sub(strip_ansi_len(&prefix));
279 let padding = " ".repeat(padding_needed);
280
281 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
282 } else {
283 let info_part = task
284 .info
285 .as_ref()
286 .map(|(s, _)| format!(" | {}", s))
287 .unwrap_or_default();
288
289 eprintln!(
290 "{label:>12} {name:<name_width$}{info_part} {duration}",
291 label = task.step.label(),
292 name = truncate_with_ellipsis(&task.example_name, name_width),
293 );
294 }
295 }
296
297 fn print_status_lines(inner: &mut ProgressInner) {
298 if !inner.is_tty || inner.in_progress.is_empty() {
299 inner.status_lines_displayed = 0;
300 return;
301 }
302
303 let reset = "\x1b[0m";
304 let bold = "\x1b[1m";
305 let dim = "\x1b[2m";
306
307 // Build the done/in-progress/total label
308 let done_count = inner.completed.len();
309 let in_progress_count = inner.in_progress.len();
310 let failed_count = inner.failed_examples;
311
312 let failed_label = if failed_count > 0 {
313 format!(" {} failed ", failed_count)
314 } else {
315 String::new()
316 };
317
318 let range_label = format!(
319 " {}/{}/{} ",
320 done_count, in_progress_count, inner.total_examples
321 );
322
323 // Print a divider line with failed count on left, range label on right
324 let failed_visible_len = strip_ansi_len(&failed_label);
325 let range_visible_len = range_label.len();
326 let middle_divider_len = inner
327 .terminal_width
328 .saturating_sub(MARGIN * 2)
329 .saturating_sub(failed_visible_len)
330 .saturating_sub(range_visible_len);
331 let left_divider = "─".repeat(MARGIN);
332 let middle_divider = "─".repeat(middle_divider_len);
333 let right_divider = "─".repeat(MARGIN);
334 eprintln!(
335 "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
336 );
337
338 let mut tasks: Vec<_> = inner.in_progress.iter().collect();
339 tasks.sort_by_key(|(name, _)| *name);
340
341 let total_tasks = tasks.len();
342 let mut lines_printed = 0;
343
344 for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
345 let elapsed = format_duration(task.started_at.elapsed());
346 let substatus_part = task
347 .substatus
348 .as_ref()
349 .map(|s| truncate_with_ellipsis(s, 30))
350 .unwrap_or_default();
351
352 let step_label = task.step.label();
353 let step_color = task.step.color_code();
354 let name_width = inner.max_example_name_len;
355 let truncated_name = truncate_with_ellipsis(name, name_width);
356
357 let prefix = format!(
358 "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
359 name = truncated_name,
360 );
361
362 let duration_with_margin = format!("{elapsed} ");
363 let padding_needed = inner
364 .terminal_width
365 .saturating_sub(MARGIN)
366 .saturating_sub(duration_with_margin.len())
367 .saturating_sub(strip_ansi_len(&prefix));
368 let padding = " ".repeat(padding_needed);
369
370 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
371 lines_printed += 1;
372 }
373
374 // Show "+N more" on its own line if there are more tasks
375 if total_tasks > MAX_STATUS_LINES {
376 let remaining = total_tasks - MAX_STATUS_LINES;
377 eprintln!("{:>12} +{remaining} more", "");
378 lines_printed += 1;
379 }
380
381 inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
382 let _ = std::io::stderr().flush();
383 }
384
385 pub fn finalize(&self) {
386 let ticker = {
387 let mut inner = self.inner.lock().unwrap();
388 inner.ticker.take()
389 };
390
391 if let Some(ticker) = ticker {
392 let _ = ticker.join();
393 }
394
395 let mut inner = self.inner.lock().unwrap();
396 Self::clear_status_lines(&mut inner);
397
398 // Print summary if there were failures
399 if inner.failed_examples > 0 {
400 let total_examples = inner.total_examples;
401 let percentage = if total_examples > 0 {
402 inner.failed_examples as f64 / total_examples as f64 * 100.0
403 } else {
404 0.0
405 };
406 eprintln!(
407 "\n{} of {} examples failed ({:.1}%)",
408 inner.failed_examples, total_examples, percentage
409 );
410 }
411 }
412}
413
414pub struct StepProgress {
415 progress: Arc<Progress>,
416 step: Step,
417 example_name: String,
418}
419
420impl StepProgress {
421 pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
422 let mut inner = self.progress.inner.lock().unwrap();
423 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
424 task.substatus = Some(substatus.into().into_owned());
425 Progress::clear_status_lines(&mut inner);
426 Progress::print_status_lines(&mut inner);
427 }
428 }
429
430 pub fn clear_substatus(&self) {
431 let mut inner = self.progress.inner.lock().unwrap();
432 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
433 task.substatus = None;
434 Progress::clear_status_lines(&mut inner);
435 Progress::print_status_lines(&mut inner);
436 }
437 }
438
439 pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
440 let mut inner = self.progress.inner.lock().unwrap();
441 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
442 task.info = Some((info.into(), style));
443 }
444 }
445}
446
447impl Drop for StepProgress {
448 fn drop(&mut self) {
449 self.progress.finish(self.step, &self.example_name);
450 }
451}
452
/// `log` crate backend that writes records to stderr.
///
/// When the global [`Progress`] exists, records are routed through it so the
/// status lines are erased before the message is printed.
struct ProgressLogger;
454
455impl Log for ProgressLogger {
456 fn enabled(&self, metadata: &Metadata) -> bool {
457 metadata.level() <= Level::Info
458 }
459
460 fn log(&self, record: &Record) {
461 if !self.enabled(record.metadata()) {
462 return;
463 }
464
465 let level_color = match record.level() {
466 Level::Error => "\x1b[31m",
467 Level::Warn => "\x1b[33m",
468 Level::Info => "\x1b[32m",
469 Level::Debug => "\x1b[34m",
470 Level::Trace => "\x1b[35m",
471 };
472 let reset = "\x1b[0m";
473 let bold = "\x1b[1m";
474
475 let level_label = match record.level() {
476 Level::Error => "Error",
477 Level::Warn => "Warn",
478 Level::Info => "Info",
479 Level::Debug => "Debug",
480 Level::Trace => "Trace",
481 };
482
483 let message = format!(
484 "{bold}{level_color}{level_label:>12}{reset} {}",
485 record.args()
486 );
487
488 if let Some(progress) = GLOBAL.get() {
489 progress.log(&message);
490 } else {
491 eprintln!("{}", message);
492 }
493 }
494
495 fn flush(&self) {
496 let _ = std::io::stderr().flush();
497 }
498}
499
500#[cfg(unix)]
501fn get_terminal_width() -> usize {
502 unsafe {
503 let mut winsize: libc::winsize = std::mem::zeroed();
504 if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
505 && winsize.ws_col > 0
506 {
507 winsize.ws_col as usize
508 } else {
509 80
510 }
511 }
512}
513
#[cfg(not(unix))]
/// Returns a fixed width of 80 columns; no terminal-size query is implemented off Unix.
fn get_terminal_width() -> usize {
    const DEFAULT_COLUMNS: usize = 80;
    DEFAULT_COLUMNS
}
518
/// Counts the visible `char`s in `s`, skipping ANSI escape sequences.
///
/// An escape starts at `ESC` (`\x1b`) and ends at the next `m` inclusive, which
/// covers the SGR style codes used in this file.
fn strip_ansi_len(s: &str) -> usize {
    let (visible, _) = s.chars().fold((0usize, false), |(visible, in_escape), ch| {
        match (in_escape, ch) {
            (_, '\x1b') => (visible, true),
            (true, 'm') => (visible, false),
            (true, _) => (visible, true),
            (false, _) => (visible + 1, false),
        }
    });
    visible
}
535
/// Shortens `s` to at most `max_len` characters, ending it with `…` when it does not fit.
///
/// Lengths are counted in `char`s, not bytes, for two reasons:
/// - multi-byte UTF-8 input is never cut mid-character (byte slicing used to
///   panic there);
/// - the result lines up with `format!` width padding, which also counts `char`s.
///
/// Wide glyphs (e.g. CJK) still take two terminal columns each; that is not
/// accounted for. A non-empty `s` with `max_len == 0` yields a lone `…`.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    if s.chars().count() <= max_len {
        return s.to_string();
    }
    // Keep room for the ellipsis itself.
    let mut truncated: String = s.chars().take(max_len.saturating_sub(1)).collect();
    truncated.push('…');
    truncated
}
543
/// Formats a duration compactly for display.
///
/// - under one second: whole milliseconds (`"500ms"`);
/// - under one minute: seconds with one decimal (`"1.5s"`);
/// - otherwise: minutes with one decimal (`"1.5m"`).
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    match duration.as_millis() {
        ms if ms < 1_000 => format!("{ms}ms"),
        ms if ms < 60_000 => format!("{:.1}s", ms as f32 / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms as f32 / MILLIS_PER_MINUTE),
    }
}