1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 io::{IsTerminal, Write},
5 sync::{Arc, Mutex, OnceLock},
6 time::{Duration, Instant},
7};
8
9use log::{Level, Log, Metadata, Record};
10
11pub struct Progress {
12 inner: Mutex<ProgressInner>,
13}
14
15struct ProgressInner {
16 completed: Vec<CompletedTask>,
17 in_progress: HashMap<String, InProgressTask>,
18 is_tty: bool,
19 terminal_width: usize,
20 max_example_name_len: usize,
21 status_lines_displayed: usize,
22 total_examples: usize,
23 failed_examples: usize,
24 last_line_is_logging: bool,
25}
26
27#[derive(Clone)]
28struct InProgressTask {
29 step: Step,
30 started_at: Instant,
31 substatus: Option<String>,
32 info: Option<(String, InfoStyle)>,
33}
34
35struct CompletedTask {
36 step: Step,
37 example_name: String,
38 duration: Duration,
39 info: Option<(String, InfoStyle)>,
40}
41
/// A stage of per-example processing that progress is reported for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Step {
    /// Loading the project (labelled "Load").
    LoadProject,
    /// Gathering context.
    Context,
    /// Formatting the prompt (labelled "Format").
    FormatPrompt,
    /// Running the prediction.
    Predict,
    /// Scoring the result.
    Score,
    /// Synthesizing output.
    Synthesize,
}
51
/// How a task's info text is rendered on its completion line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed as-is.
    Normal,
    /// Highlighted (yellow) when stderr is a terminal.
    Warning,
}
57
58impl Step {
59 pub fn label(&self) -> &'static str {
60 match self {
61 Step::LoadProject => "Load",
62 Step::Context => "Context",
63 Step::FormatPrompt => "Format",
64 Step::Predict => "Predict",
65 Step::Score => "Score",
66 Step::Synthesize => "Synthesize",
67 }
68 }
69
70 fn color_code(&self) -> &'static str {
71 match self {
72 Step::LoadProject => "\x1b[33m",
73 Step::Context => "\x1b[35m",
74 Step::FormatPrompt => "\x1b[34m",
75 Step::Predict => "\x1b[32m",
76 Step::Score => "\x1b[31m",
77 Step::Synthesize => "\x1b[36m",
78 }
79 }
80}
81
82static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
83static LOGGER: ProgressLogger = ProgressLogger;
84
/// Columns subtracted from the terminal width when laying out full-width
/// lines (dividers and right-aligned durations).
const MARGIN: usize = 4;
/// Maximum number of running tasks listed individually; any remainder is
/// summarized on a single "+N more" line.
const MAX_STATUS_LINES: usize = 10;
87
88impl Progress {
89 /// Returns the global Progress instance, initializing it if necessary.
90 pub fn global() -> Arc<Progress> {
91 GLOBAL
92 .get_or_init(|| {
93 let progress = Arc::new(Self {
94 inner: Mutex::new(ProgressInner {
95 completed: Vec::new(),
96 in_progress: HashMap::new(),
97 is_tty: std::io::stderr().is_terminal(),
98 terminal_width: get_terminal_width(),
99 max_example_name_len: 0,
100 status_lines_displayed: 0,
101 total_examples: 0,
102 failed_examples: 0,
103 last_line_is_logging: false,
104 }),
105 });
106 let _ = log::set_logger(&LOGGER);
107 log::set_max_level(log::LevelFilter::Error);
108 progress
109 })
110 .clone()
111 }
112
113 pub fn set_total_examples(&self, total: usize) {
114 let mut inner = self.inner.lock().unwrap();
115 inner.total_examples = total;
116 }
117
118 pub fn increment_failed(&self) {
119 let mut inner = self.inner.lock().unwrap();
120 inner.failed_examples += 1;
121 }
122
123 /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
124 /// This should be used for any output that needs to appear above the status lines.
125 fn log(&self, message: &str) {
126 let mut inner = self.inner.lock().unwrap();
127 Self::clear_status_lines(&mut inner);
128
129 if !inner.last_line_is_logging {
130 let reset = "\x1b[0m";
131 let dim = "\x1b[2m";
132 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
133 eprintln!("{dim}{divider}{reset}");
134 inner.last_line_is_logging = true;
135 }
136
137 eprintln!("{}", message);
138 }
139
140 pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
141 let mut inner = self.inner.lock().unwrap();
142
143 Self::clear_status_lines(&mut inner);
144
145 inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
146 inner.in_progress.insert(
147 example_name.to_string(),
148 InProgressTask {
149 step,
150 started_at: Instant::now(),
151 substatus: None,
152 info: None,
153 },
154 );
155
156 Self::print_status_lines(&mut inner);
157
158 StepProgress {
159 progress: self.clone(),
160 step,
161 example_name: example_name.to_string(),
162 }
163 }
164
165 fn finish(&self, step: Step, example_name: &str) {
166 let mut inner = self.inner.lock().unwrap();
167
168 let Some(task) = inner.in_progress.remove(example_name) else {
169 return;
170 };
171
172 if task.step == step {
173 inner.completed.push(CompletedTask {
174 step: task.step,
175 example_name: example_name.to_string(),
176 duration: task.started_at.elapsed(),
177 info: task.info,
178 });
179
180 Self::clear_status_lines(&mut inner);
181 Self::print_logging_closing_divider(&mut inner);
182 Self::print_completed(&inner, inner.completed.last().unwrap());
183 Self::print_status_lines(&mut inner);
184 } else {
185 inner.in_progress.insert(example_name.to_string(), task);
186 }
187 }
188
189 fn print_logging_closing_divider(inner: &mut ProgressInner) {
190 if inner.last_line_is_logging {
191 let reset = "\x1b[0m";
192 let dim = "\x1b[2m";
193 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
194 eprintln!("{dim}{divider}{reset}");
195 inner.last_line_is_logging = false;
196 }
197 }
198
199 fn clear_status_lines(inner: &mut ProgressInner) {
200 if inner.is_tty && inner.status_lines_displayed > 0 {
201 // Move up and clear each line we previously displayed
202 for _ in 0..inner.status_lines_displayed {
203 eprint!("\x1b[A\x1b[K");
204 }
205 let _ = std::io::stderr().flush();
206 inner.status_lines_displayed = 0;
207 }
208 }
209
210 fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
211 let duration = format_duration(task.duration);
212 let name_width = inner.max_example_name_len;
213
214 if inner.is_tty {
215 let reset = "\x1b[0m";
216 let bold = "\x1b[1m";
217 let dim = "\x1b[2m";
218
219 let yellow = "\x1b[33m";
220 let info_part = task
221 .info
222 .as_ref()
223 .map(|(s, style)| {
224 if *style == InfoStyle::Warning {
225 format!("{yellow}{s}{reset}")
226 } else {
227 s.to_string()
228 }
229 })
230 .unwrap_or_default();
231
232 let prefix = format!(
233 "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
234 color = task.step.color_code(),
235 label = task.step.label(),
236 name = task.example_name,
237 );
238
239 let duration_with_margin = format!("{duration} ");
240 let padding_needed = inner
241 .terminal_width
242 .saturating_sub(MARGIN)
243 .saturating_sub(duration_with_margin.len())
244 .saturating_sub(strip_ansi_len(&prefix));
245 let padding = " ".repeat(padding_needed);
246
247 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
248 } else {
249 let info_part = task
250 .info
251 .as_ref()
252 .map(|(s, _)| format!(" | {}", s))
253 .unwrap_or_default();
254
255 eprintln!(
256 "{label:>12} {name:<name_width$}{info_part} {duration}",
257 label = task.step.label(),
258 name = task.example_name,
259 );
260 }
261 }
262
263 fn print_status_lines(inner: &mut ProgressInner) {
264 if !inner.is_tty || inner.in_progress.is_empty() {
265 inner.status_lines_displayed = 0;
266 return;
267 }
268
269 let reset = "\x1b[0m";
270 let bold = "\x1b[1m";
271 let dim = "\x1b[2m";
272
273 // Build the done/in-progress/total label
274 let done_count = inner.completed.len();
275 let in_progress_count = inner.in_progress.len();
276 let failed_count = inner.failed_examples;
277
278 let failed_label = if failed_count > 0 {
279 format!(" {} failed ", failed_count)
280 } else {
281 String::new()
282 };
283
284 let range_label = format!(
285 " {}/{}/{} ",
286 done_count, in_progress_count, inner.total_examples
287 );
288
289 // Print a divider line with failed count on left, range label on right
290 let failed_visible_len = strip_ansi_len(&failed_label);
291 let range_visible_len = range_label.len();
292 let middle_divider_len = inner
293 .terminal_width
294 .saturating_sub(MARGIN * 2)
295 .saturating_sub(failed_visible_len)
296 .saturating_sub(range_visible_len);
297 let left_divider = "─".repeat(MARGIN);
298 let middle_divider = "─".repeat(middle_divider_len);
299 let right_divider = "─".repeat(MARGIN);
300 eprintln!(
301 "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
302 );
303
304 let mut tasks: Vec<_> = inner.in_progress.iter().collect();
305 tasks.sort_by_key(|(name, _)| *name);
306
307 let total_tasks = tasks.len();
308 let mut lines_printed = 0;
309
310 for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
311 let elapsed = format_duration(task.started_at.elapsed());
312 let substatus_part = task
313 .substatus
314 .as_ref()
315 .map(|s| truncate_with_ellipsis(s, 30))
316 .unwrap_or_default();
317
318 let step_label = task.step.label();
319 let step_color = task.step.color_code();
320 let name_width = inner.max_example_name_len;
321
322 let prefix = format!(
323 "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
324 name = name,
325 );
326
327 let duration_with_margin = format!("{elapsed} ");
328 let padding_needed = inner
329 .terminal_width
330 .saturating_sub(MARGIN)
331 .saturating_sub(duration_with_margin.len())
332 .saturating_sub(strip_ansi_len(&prefix));
333 let padding = " ".repeat(padding_needed);
334
335 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
336 lines_printed += 1;
337 }
338
339 // Show "+N more" on its own line if there are more tasks
340 if total_tasks > MAX_STATUS_LINES {
341 let remaining = total_tasks - MAX_STATUS_LINES;
342 eprintln!("{:>12} +{remaining} more", "");
343 lines_printed += 1;
344 }
345
346 inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
347 let _ = std::io::stderr().flush();
348 }
349
350 pub fn finalize(&self) {
351 let mut inner = self.inner.lock().unwrap();
352 Self::clear_status_lines(&mut inner);
353
354 // Print summary if there were failures
355 if inner.failed_examples > 0 {
356 let total_processed = inner.completed.len() + inner.failed_examples;
357 let percentage = if total_processed > 0 {
358 inner.failed_examples as f64 / total_processed as f64 * 100.0
359 } else {
360 0.0
361 };
362 eprintln!(
363 "\n{} of {} examples failed ({:.1}%)",
364 inner.failed_examples, total_processed, percentage
365 );
366 }
367 }
368}
369
370pub struct StepProgress {
371 progress: Arc<Progress>,
372 step: Step,
373 example_name: String,
374}
375
376impl StepProgress {
377 pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
378 let mut inner = self.progress.inner.lock().unwrap();
379 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
380 task.substatus = Some(substatus.into().into_owned());
381 Progress::clear_status_lines(&mut inner);
382 Progress::print_status_lines(&mut inner);
383 }
384 }
385
386 pub fn clear_substatus(&self) {
387 let mut inner = self.progress.inner.lock().unwrap();
388 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
389 task.substatus = None;
390 Progress::clear_status_lines(&mut inner);
391 Progress::print_status_lines(&mut inner);
392 }
393 }
394
395 pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
396 let mut inner = self.progress.inner.lock().unwrap();
397 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
398 task.info = Some((info.into(), style));
399 }
400 }
401}
402
403impl Drop for StepProgress {
404 fn drop(&mut self) {
405 self.progress.finish(self.step, &self.example_name);
406 }
407}
408
/// `log` backend that prints records through the global progress display.
struct ProgressLogger;
410
411impl Log for ProgressLogger {
412 fn enabled(&self, metadata: &Metadata) -> bool {
413 metadata.level() <= Level::Info
414 }
415
416 fn log(&self, record: &Record) {
417 if !self.enabled(record.metadata()) {
418 return;
419 }
420
421 let level_color = match record.level() {
422 Level::Error => "\x1b[31m",
423 Level::Warn => "\x1b[33m",
424 Level::Info => "\x1b[32m",
425 Level::Debug => "\x1b[34m",
426 Level::Trace => "\x1b[35m",
427 };
428 let reset = "\x1b[0m";
429 let bold = "\x1b[1m";
430
431 let level_label = match record.level() {
432 Level::Error => "Error",
433 Level::Warn => "Warn",
434 Level::Info => "Info",
435 Level::Debug => "Debug",
436 Level::Trace => "Trace",
437 };
438
439 let message = format!(
440 "{bold}{level_color}{level_label:>12}{reset} {}",
441 record.args()
442 );
443
444 if let Some(progress) = GLOBAL.get() {
445 progress.log(&message);
446 } else {
447 eprintln!("{}", message);
448 }
449 }
450
451 fn flush(&self) {
452 let _ = std::io::stderr().flush();
453 }
454}
455
456#[cfg(unix)]
457fn get_terminal_width() -> usize {
458 unsafe {
459 let mut winsize: libc::winsize = std::mem::zeroed();
460 if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
461 && winsize.ws_col > 0
462 {
463 winsize.ws_col as usize
464 } else {
465 80
466 }
467 }
468}
469
/// Returns the assumed terminal width on platforms without a size query here:
/// always the classic 80 columns.
#[cfg(not(unix))]
fn get_terminal_width() -> usize {
    const DEFAULT_COLUMNS: usize = 80;
    DEFAULT_COLUMNS
}
474
/// Counts the visible characters in `s`, ignoring ANSI escape sequences.
///
/// Every character from an ESC (`\x1b`) up to and including the next `m` is
/// skipped, which covers the SGR color/style codes used in this file. The
/// result is a `char` count, not a byte length.
fn strip_ansi_len(s: &str) -> usize {
    let mut visible = 0;
    let mut chars = s.chars();
    while let Some(ch) = chars.next() {
        if ch == '\x1b' {
            // Consume the escape sequence through its terminating 'm'.
            for inner in chars.by_ref() {
                if inner == 'm' {
                    break;
                }
            }
        } else {
            visible += 1;
        }
    }
    visible
}
491
/// Shortens `s` to at most `max_len` characters, ending with `…` when truncated.
///
/// Lengths are counted in `char`s, not bytes. This matches `strip_ansi_len`
/// and never cuts inside a multi-byte UTF-8 sequence; byte slicing there would
/// panic. When truncation happens the ellipsis counts toward `max_len`, so
/// with `max_len == 0` any non-empty input becomes just `"…"`.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    if s.chars().count() <= max_len {
        return s.to_string();
    }
    let keep = max_len.saturating_sub(1);
    // Byte offset of the first dropped char; always on a char boundary.
    let cut = s
        .char_indices()
        .nth(keep)
        .map_or(s.len(), |(idx, _)| idx);
    format!("{}…", &s[..cut])
}
499
/// Formats `duration` compactly for progress lines.
///
/// The format depends on the whole-millisecond value:
/// - under one second: whole milliseconds, e.g. `"250ms"`;
/// - under one minute: seconds with one decimal, e.g. `"1.5s"`;
/// - otherwise: minutes with one decimal, e.g. `"2.3m"`.
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    let ms = duration.as_millis() as f32;
    match ms {
        m if m < MILLIS_PER_SECOND => format!("{m}ms"),
        m if m < MILLIS_PER_MINUTE => format!("{:.1}s", m / MILLIS_PER_SECOND),
        m => format!("{:.1}m", m / MILLIS_PER_MINUTE),
    }
}