1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 io::{IsTerminal, Write},
5 sync::{Arc, Mutex, OnceLock},
6 time::{Duration, Instant},
7};
8
9use log::{Level, Log, Metadata, Record};
10
/// Progress reporter that writes completed-step lines and a live status block to stderr.
///
/// A single shared instance is obtained through `Progress::global`.
pub struct Progress {
    /// All mutable reporter state; every operation locks this mutex before touching
    /// the state or writing to stderr.
    inner: Mutex<ProgressInner>,
}
14
/// Mutable state behind the `Progress` mutex.
struct ProgressInner {
    /// Every task that has finished so far, in completion order.
    completed: Vec<CompletedTask>,
    /// Tasks currently running, keyed by example name.
    in_progress: HashMap<String, InProgressTask>,
    /// Whether stderr was a terminal when the reporter was created.
    is_tty: bool,
    /// Terminal width in columns, as returned by `get_terminal_width` at creation time.
    terminal_width: usize,
    /// Longest example name seen so far; used as the width of the name column.
    max_example_name_len: usize,
    /// Number of status lines currently on screen (including the divider) that must be
    /// erased before anything else is printed.
    status_lines_displayed: usize,
    /// Total number of examples, as set by `set_total_examples`.
    total_examples: usize,
    /// Number of failures recorded via `increment_failed`.
    failed_examples: usize,
    /// True while the most recent output was a log message, i.e. an opening log divider
    /// has been printed and the matching closing divider has not.
    last_line_is_logging: bool,
}
26
/// A step that is currently running for one example.
#[derive(Clone)]
struct InProgressTask {
    /// Which pipeline step is running.
    step: Step,
    /// When the step was started; elapsed time is measured from here.
    started_at: Instant,
    /// Optional short status shown next to the task in the live status block.
    substatus: Option<String>,
    /// Optional text (with its style) carried over to the completed-task line.
    info: Option<(String, InfoStyle)>,
}
34
/// Record of a finished step for one example.
struct CompletedTask {
    /// Which pipeline step finished.
    step: Step,
    /// Name of the example the step ran for.
    example_name: String,
    /// Time between the step's start and its completion.
    duration: Duration,
    /// Optional text (with its style) printed on the completed-task line.
    info: Option<(String, InfoStyle)>,
}
41
/// The kinds of step whose progress can be reported.
///
/// Each variant has a short display label (`Step::label`) and a terminal color.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
    /// Displayed as "Load".
    LoadProject,
    /// Displayed as "Context".
    Context,
    /// Displayed as "Format".
    FormatPrompt,
    /// Displayed as "Predict".
    Predict,
    /// Displayed as "Score".
    Score,
}
50
/// How a task's info text is rendered on its completed line.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed without extra coloring.
    Normal,
    /// Printed in yellow when stderr is a terminal.
    Warning,
}
56
57impl Step {
58 pub fn label(&self) -> &'static str {
59 match self {
60 Step::LoadProject => "Load",
61 Step::Context => "Context",
62 Step::FormatPrompt => "Format",
63 Step::Predict => "Predict",
64 Step::Score => "Score",
65 }
66 }
67
68 fn color_code(&self) -> &'static str {
69 match self {
70 Step::LoadProject => "\x1b[33m",
71 Step::Context => "\x1b[35m",
72 Step::FormatPrompt => "\x1b[34m",
73 Step::Predict => "\x1b[32m",
74 Step::Score => "\x1b[31m",
75 }
76 }
77}
78
/// Process-wide `Progress` instance, created on the first call to `Progress::global`.
static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
/// Logger registered with the `log` crate when the global `Progress` is created.
static LOGGER: ProgressLogger = ProgressLogger;

/// Number of columns subtracted from the terminal width when sizing dividers and
/// right-aligning durations.
const MARGIN: usize = 4;
/// Maximum number of in-progress tasks listed individually; the rest are summarized
/// as "+N more".
const MAX_STATUS_LINES: usize = 10;
84
85impl Progress {
86 /// Returns the global Progress instance, initializing it if necessary.
87 pub fn global() -> Arc<Progress> {
88 GLOBAL
89 .get_or_init(|| {
90 let progress = Arc::new(Self {
91 inner: Mutex::new(ProgressInner {
92 completed: Vec::new(),
93 in_progress: HashMap::new(),
94 is_tty: std::io::stderr().is_terminal(),
95 terminal_width: get_terminal_width(),
96 max_example_name_len: 0,
97 status_lines_displayed: 0,
98 total_examples: 0,
99 failed_examples: 0,
100 last_line_is_logging: false,
101 }),
102 });
103 let _ = log::set_logger(&LOGGER);
104 log::set_max_level(log::LevelFilter::Error);
105 progress
106 })
107 .clone()
108 }
109
110 pub fn set_total_examples(&self, total: usize) {
111 let mut inner = self.inner.lock().unwrap();
112 inner.total_examples = total;
113 }
114
115 pub fn increment_failed(&self) {
116 let mut inner = self.inner.lock().unwrap();
117 inner.failed_examples += 1;
118 }
119
120 /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
121 /// This should be used for any output that needs to appear above the status lines.
122 fn log(&self, message: &str) {
123 let mut inner = self.inner.lock().unwrap();
124 Self::clear_status_lines(&mut inner);
125
126 if !inner.last_line_is_logging {
127 let reset = "\x1b[0m";
128 let dim = "\x1b[2m";
129 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
130 eprintln!("{dim}{divider}{reset}");
131 inner.last_line_is_logging = true;
132 }
133
134 eprintln!("{}", message);
135 }
136
137 pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
138 let mut inner = self.inner.lock().unwrap();
139
140 Self::clear_status_lines(&mut inner);
141
142 inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
143 inner.in_progress.insert(
144 example_name.to_string(),
145 InProgressTask {
146 step,
147 started_at: Instant::now(),
148 substatus: None,
149 info: None,
150 },
151 );
152
153 Self::print_status_lines(&mut inner);
154
155 StepProgress {
156 progress: self.clone(),
157 step,
158 example_name: example_name.to_string(),
159 }
160 }
161
162 fn finish(&self, step: Step, example_name: &str) {
163 let mut inner = self.inner.lock().unwrap();
164
165 let Some(task) = inner.in_progress.remove(example_name) else {
166 return;
167 };
168
169 if task.step == step {
170 inner.completed.push(CompletedTask {
171 step: task.step,
172 example_name: example_name.to_string(),
173 duration: task.started_at.elapsed(),
174 info: task.info,
175 });
176
177 Self::clear_status_lines(&mut inner);
178 Self::print_logging_closing_divider(&mut inner);
179 Self::print_completed(&inner, inner.completed.last().unwrap());
180 Self::print_status_lines(&mut inner);
181 } else {
182 inner.in_progress.insert(example_name.to_string(), task);
183 }
184 }
185
186 fn print_logging_closing_divider(inner: &mut ProgressInner) {
187 if inner.last_line_is_logging {
188 let reset = "\x1b[0m";
189 let dim = "\x1b[2m";
190 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
191 eprintln!("{dim}{divider}{reset}");
192 inner.last_line_is_logging = false;
193 }
194 }
195
196 fn clear_status_lines(inner: &mut ProgressInner) {
197 if inner.is_tty && inner.status_lines_displayed > 0 {
198 // Move up and clear each line we previously displayed
199 for _ in 0..inner.status_lines_displayed {
200 eprint!("\x1b[A\x1b[K");
201 }
202 let _ = std::io::stderr().flush();
203 inner.status_lines_displayed = 0;
204 }
205 }
206
207 fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
208 let duration = format_duration(task.duration);
209 let name_width = inner.max_example_name_len;
210
211 if inner.is_tty {
212 let reset = "\x1b[0m";
213 let bold = "\x1b[1m";
214 let dim = "\x1b[2m";
215
216 let yellow = "\x1b[33m";
217 let info_part = task
218 .info
219 .as_ref()
220 .map(|(s, style)| {
221 if *style == InfoStyle::Warning {
222 format!("{yellow}{s}{reset}")
223 } else {
224 s.to_string()
225 }
226 })
227 .unwrap_or_default();
228
229 let prefix = format!(
230 "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
231 color = task.step.color_code(),
232 label = task.step.label(),
233 name = task.example_name,
234 );
235
236 let duration_with_margin = format!("{duration} ");
237 let padding_needed = inner
238 .terminal_width
239 .saturating_sub(MARGIN)
240 .saturating_sub(duration_with_margin.len())
241 .saturating_sub(strip_ansi_len(&prefix));
242 let padding = " ".repeat(padding_needed);
243
244 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
245 } else {
246 let info_part = task
247 .info
248 .as_ref()
249 .map(|(s, _)| format!(" | {}", s))
250 .unwrap_or_default();
251
252 eprintln!(
253 "{label:>12} {name:<name_width$}{info_part} {duration}",
254 label = task.step.label(),
255 name = task.example_name,
256 );
257 }
258 }
259
260 fn print_status_lines(inner: &mut ProgressInner) {
261 if !inner.is_tty || inner.in_progress.is_empty() {
262 inner.status_lines_displayed = 0;
263 return;
264 }
265
266 let reset = "\x1b[0m";
267 let bold = "\x1b[1m";
268 let dim = "\x1b[2m";
269
270 // Build the done/in-progress/total label
271 let done_count = inner.completed.len();
272 let in_progress_count = inner.in_progress.len();
273 let failed_count = inner.failed_examples;
274
275 let failed_label = if failed_count > 0 {
276 format!(" {} failed ", failed_count)
277 } else {
278 String::new()
279 };
280
281 let range_label = format!(
282 " {}/{}/{} ",
283 done_count, in_progress_count, inner.total_examples
284 );
285
286 // Print a divider line with failed count on left, range label on right
287 let failed_visible_len = strip_ansi_len(&failed_label);
288 let range_visible_len = range_label.len();
289 let middle_divider_len = inner
290 .terminal_width
291 .saturating_sub(MARGIN * 2)
292 .saturating_sub(failed_visible_len)
293 .saturating_sub(range_visible_len);
294 let left_divider = "─".repeat(MARGIN);
295 let middle_divider = "─".repeat(middle_divider_len);
296 let right_divider = "─".repeat(MARGIN);
297 eprintln!(
298 "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
299 );
300
301 let mut tasks: Vec<_> = inner.in_progress.iter().collect();
302 tasks.sort_by_key(|(name, _)| *name);
303
304 let total_tasks = tasks.len();
305 let mut lines_printed = 0;
306
307 for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
308 let elapsed = format_duration(task.started_at.elapsed());
309 let substatus_part = task
310 .substatus
311 .as_ref()
312 .map(|s| truncate_with_ellipsis(s, 30))
313 .unwrap_or_default();
314
315 let step_label = task.step.label();
316 let step_color = task.step.color_code();
317 let name_width = inner.max_example_name_len;
318
319 let prefix = format!(
320 "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
321 name = name,
322 );
323
324 let duration_with_margin = format!("{elapsed} ");
325 let padding_needed = inner
326 .terminal_width
327 .saturating_sub(MARGIN)
328 .saturating_sub(duration_with_margin.len())
329 .saturating_sub(strip_ansi_len(&prefix));
330 let padding = " ".repeat(padding_needed);
331
332 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
333 lines_printed += 1;
334 }
335
336 // Show "+N more" on its own line if there are more tasks
337 if total_tasks > MAX_STATUS_LINES {
338 let remaining = total_tasks - MAX_STATUS_LINES;
339 eprintln!("{:>12} +{remaining} more", "");
340 lines_printed += 1;
341 }
342
343 inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
344 let _ = std::io::stderr().flush();
345 }
346
347 pub fn finalize(&self) {
348 let mut inner = self.inner.lock().unwrap();
349 Self::clear_status_lines(&mut inner);
350
351 // Print summary if there were failures
352 if inner.failed_examples > 0 {
353 let total_processed = inner.completed.len() + inner.failed_examples;
354 let percentage = if total_processed > 0 {
355 inner.failed_examples as f64 / total_processed as f64 * 100.0
356 } else {
357 0.0
358 };
359 eprintln!(
360 "\n{} of {} examples failed ({:.1}%)",
361 inner.failed_examples, total_processed, percentage
362 );
363 }
364 }
365}
366
/// Handle for a running step, returned by `Progress::start`.
///
/// Dropping the handle reports the step as finished.
pub struct StepProgress {
    /// Reporter the step was registered with.
    progress: Arc<Progress>,
    /// The step this handle represents.
    step: Step,
    /// Name of the example the step runs for; also the key of the in-progress entry.
    example_name: String,
}
372
373impl StepProgress {
374 pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
375 let mut inner = self.progress.inner.lock().unwrap();
376 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
377 task.substatus = Some(substatus.into().into_owned());
378 Progress::clear_status_lines(&mut inner);
379 Progress::print_status_lines(&mut inner);
380 }
381 }
382
383 pub fn clear_substatus(&self) {
384 let mut inner = self.progress.inner.lock().unwrap();
385 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
386 task.substatus = None;
387 Progress::clear_status_lines(&mut inner);
388 Progress::print_status_lines(&mut inner);
389 }
390 }
391
392 pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
393 let mut inner = self.progress.inner.lock().unwrap();
394 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
395 task.info = Some((info.into(), style));
396 }
397 }
398}
399
/// Reports the step as finished when the handle goes out of scope.
impl Drop for StepProgress {
    fn drop(&mut self) {
        // `finish` only records the completion if the example's in-progress entry
        // still belongs to this handle's step.
        self.progress.finish(self.step, &self.example_name);
    }
}
405
/// `log` backend that routes messages through the global `Progress` when it exists,
/// and straight to stderr otherwise.
struct ProgressLogger;
407
408impl Log for ProgressLogger {
409 fn enabled(&self, metadata: &Metadata) -> bool {
410 metadata.level() <= Level::Info
411 }
412
413 fn log(&self, record: &Record) {
414 if !self.enabled(record.metadata()) {
415 return;
416 }
417
418 let level_color = match record.level() {
419 Level::Error => "\x1b[31m",
420 Level::Warn => "\x1b[33m",
421 Level::Info => "\x1b[32m",
422 Level::Debug => "\x1b[34m",
423 Level::Trace => "\x1b[35m",
424 };
425 let reset = "\x1b[0m";
426 let bold = "\x1b[1m";
427
428 let level_label = match record.level() {
429 Level::Error => "Error",
430 Level::Warn => "Warn",
431 Level::Info => "Info",
432 Level::Debug => "Debug",
433 Level::Trace => "Trace",
434 };
435
436 let message = format!(
437 "{bold}{level_color}{level_label:>12}{reset} {}",
438 record.args()
439 );
440
441 if let Some(progress) = GLOBAL.get() {
442 progress.log(&message);
443 } else {
444 eprintln!("{}", message);
445 }
446 }
447
448 fn flush(&self) {
449 let _ = std::io::stderr().flush();
450 }
451}
452
453#[cfg(unix)]
454fn get_terminal_width() -> usize {
455 unsafe {
456 let mut winsize: libc::winsize = std::mem::zeroed();
457 if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
458 && winsize.ws_col > 0
459 {
460 winsize.ws_col as usize
461 } else {
462 80
463 }
464 }
465}
466
467#[cfg(not(unix))]
468fn get_terminal_width() -> usize {
469 80
470}
471
/// Counts the characters of `s` that lie outside escape sequences.
///
/// An escape sequence starts at an ESC (`\x1b`) character and runs up to and
/// including the next `m`; an unterminated sequence swallows the rest of the string.
fn strip_ansi_len(s: &str) -> usize {
    let mut visible = 0;
    let mut chars = s.chars();

    while let Some(ch) = chars.next() {
        if ch != '\x1b' {
            visible += 1;
            continue;
        }
        // Skip everything through the terminating 'm' (or to the end of input).
        for skipped in chars.by_ref() {
            if skipped == 'm' {
                break;
            }
        }
    }

    visible
}
488
/// Shortens `s` to at most `max_len` characters, ending it with `…` when truncated.
///
/// Length is measured in characters (Unicode scalar values), not bytes, so the cut
/// always falls on a character boundary and multi-byte text cannot cause a panic.
/// Strings that already fit are returned unchanged. When truncation happens the
/// result is the first `max_len - 1` characters followed by `…` (just `…` when
/// `max_len` is 0 or 1).
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    // No character at index `max_len` means the string has at most `max_len` chars.
    if s.chars().nth(max_len).is_none() {
        return s.to_string();
    }

    // Byte offset of the first character to drop, leaving room for the ellipsis.
    let keep_chars = max_len.saturating_sub(1);
    let cut = s
        .char_indices()
        .nth(keep_chars)
        .map_or(s.len(), |(byte_index, _)| byte_index);

    format!("{}…", &s[..cut])
}
496
/// Renders a duration compactly: whole milliseconds below one second, seconds with
/// one decimal below one minute, and minutes with one decimal otherwise.
fn format_duration(duration: Duration) -> String {
    const SECOND_IN_MILLIS: f32 = 1000.;
    const MINUTE_IN_MILLIS: f32 = 60. * 1000.;

    match duration.as_millis() as f32 {
        millis if millis < SECOND_IN_MILLIS => format!("{}ms", millis),
        millis if millis < MINUTE_IN_MILLIS => format!("{:.1}s", millis / 1_000.0),
        millis => format!("{:.1}m", millis / MINUTE_IN_MILLIS),
    }
}