1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 io::{IsTerminal, Write},
5 sync::{Arc, Mutex, OnceLock},
6 time::{Duration, Instant},
7};
8
9use log::{Level, Log, Metadata, Record};
10
/// Progress reporter that renders per-example step status to stderr.
///
/// Obtain the shared instance with `Progress::global()`; all mutable state lives
/// behind a single mutex so it can be updated from several threads.
pub struct Progress {
    // All rendering and bookkeeping state; every public method locks this first.
    inner: Mutex<ProgressInner>,
}
14
15struct ProgressInner {
16 completed: Vec<CompletedTask>,
17 in_progress: HashMap<String, InProgressTask>,
18 is_tty: bool,
19 terminal_width: usize,
20 max_example_name_len: usize,
21 status_lines_displayed: usize,
22 total_steps: usize,
23 failed_examples: usize,
24 last_line_is_logging: bool,
25 ticker: Option<std::thread::JoinHandle<()>>,
26}
27
28#[derive(Clone)]
29struct InProgressTask {
30 step: Step,
31 started_at: Instant,
32 substatus: Option<String>,
33 info: Option<(String, InfoStyle)>,
34}
35
36struct CompletedTask {
37 step: Step,
38 example_name: String,
39 duration: Duration,
40 info: Option<(String, InfoStyle)>,
41}
42
/// A unit of work an example can be in, as reported by `Progress`.
///
/// The display label and color for each variant come from `Step::label` and
/// `Step::color_code`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum Step {
    LoadProject,
    Context,
    FormatPrompt,
    Predict,
    Score,
    Synthesize,
    PullExamples,
}
53
/// How a step's info text is rendered on its completion line.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed as-is.
    Normal,
    /// Highlighted (yellow on a terminal).
    Warning,
}
59
60impl Step {
61 pub fn label(&self) -> &'static str {
62 match self {
63 Step::LoadProject => "Load",
64 Step::Context => "Context",
65 Step::FormatPrompt => "Format",
66 Step::Predict => "Predict",
67 Step::Score => "Score",
68 Step::Synthesize => "Synthesize",
69 Step::PullExamples => "Pull",
70 }
71 }
72
73 fn color_code(&self) -> &'static str {
74 match self {
75 Step::LoadProject => "\x1b[33m",
76 Step::Context => "\x1b[35m",
77 Step::FormatPrompt => "\x1b[34m",
78 Step::Predict => "\x1b[32m",
79 Step::Score => "\x1b[31m",
80 Step::Synthesize => "\x1b[36m",
81 Step::PullExamples => "\x1b[36m",
82 }
83 }
84}
85
86static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
87static LOGGER: ProgressLogger = ProgressLogger;
88
/// Maximum number of running tasks listed individually; extras collapse into "+N more".
const MAX_STATUS_LINES: usize = 10;
/// How often the background ticker redraws the status lines.
const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
/// Columns reserved at the line edges when laying out dividers and status lines.
const MARGIN: usize = 4;
92
93impl Progress {
94 /// Returns the global Progress instance, initializing it if necessary.
95 pub fn global() -> Arc<Progress> {
96 GLOBAL
97 .get_or_init(|| {
98 let progress = Arc::new(Self {
99 inner: Mutex::new(ProgressInner {
100 completed: Vec::new(),
101 in_progress: HashMap::new(),
102 is_tty: std::io::stderr().is_terminal(),
103 terminal_width: get_terminal_width(),
104 max_example_name_len: 0,
105 status_lines_displayed: 0,
106 total_steps: 0,
107 failed_examples: 0,
108 last_line_is_logging: false,
109 ticker: None,
110 }),
111 });
112 let _ = log::set_logger(&LOGGER);
113 log::set_max_level(log::LevelFilter::Error);
114 progress
115 })
116 .clone()
117 }
118
119 pub fn set_total_steps(&self, total: usize) {
120 let mut inner = self.inner.lock().unwrap();
121 inner.total_steps = total;
122 }
123
124 pub fn increment_failed(&self) {
125 let mut inner = self.inner.lock().unwrap();
126 inner.failed_examples += 1;
127 }
128
129 /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
130 /// This should be used for any output that needs to appear above the status lines.
131 fn log(&self, message: &str) {
132 let mut inner = self.inner.lock().unwrap();
133 Self::clear_status_lines(&mut inner);
134
135 if !inner.last_line_is_logging {
136 let reset = "\x1b[0m";
137 let dim = "\x1b[2m";
138 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
139 eprintln!("{dim}{divider}{reset}");
140 inner.last_line_is_logging = true;
141 }
142
143 eprintln!("{}", message);
144 }
145
146 pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
147 let mut inner = self.inner.lock().unwrap();
148
149 Self::clear_status_lines(&mut inner);
150
151 let max_name_width = inner
152 .terminal_width
153 .saturating_sub(MARGIN * 2)
154 .saturating_div(3)
155 .max(1);
156 inner.max_example_name_len = inner
157 .max_example_name_len
158 .max(example_name.len().min(max_name_width));
159 inner.in_progress.insert(
160 example_name.to_string(),
161 InProgressTask {
162 step,
163 started_at: Instant::now(),
164 substatus: None,
165 info: None,
166 },
167 );
168
169 if inner.is_tty && inner.ticker.is_none() {
170 let progress = self.clone();
171 inner.ticker = Some(std::thread::spawn(move || {
172 loop {
173 std::thread::sleep(STATUS_TICK_INTERVAL);
174
175 let mut inner = progress.inner.lock().unwrap();
176 if inner.in_progress.is_empty() {
177 break;
178 }
179
180 Progress::clear_status_lines(&mut inner);
181 Progress::print_status_lines(&mut inner);
182 }
183 }));
184 }
185
186 Self::print_status_lines(&mut inner);
187
188 StepProgress {
189 progress: self.clone(),
190 step,
191 example_name: example_name.to_string(),
192 }
193 }
194
195 fn finish(&self, step: Step, example_name: &str) {
196 let mut inner = self.inner.lock().unwrap();
197
198 let Some(task) = inner.in_progress.remove(example_name) else {
199 return;
200 };
201
202 if task.step == step {
203 inner.completed.push(CompletedTask {
204 step: task.step,
205 example_name: example_name.to_string(),
206 duration: task.started_at.elapsed(),
207 info: task.info,
208 });
209
210 Self::clear_status_lines(&mut inner);
211 Self::print_logging_closing_divider(&mut inner);
212 if let Some(last_completed) = inner.completed.last() {
213 Self::print_completed(&inner, last_completed);
214 }
215 Self::print_status_lines(&mut inner);
216 } else {
217 inner.in_progress.insert(example_name.to_string(), task);
218 }
219 }
220
221 fn print_logging_closing_divider(inner: &mut ProgressInner) {
222 if inner.last_line_is_logging {
223 let reset = "\x1b[0m";
224 let dim = "\x1b[2m";
225 let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
226 eprintln!("{dim}{divider}{reset}");
227 inner.last_line_is_logging = false;
228 }
229 }
230
231 fn clear_status_lines(inner: &mut ProgressInner) {
232 if inner.is_tty && inner.status_lines_displayed > 0 {
233 // Move up and clear each line we previously displayed
234 for _ in 0..inner.status_lines_displayed {
235 eprint!("\x1b[A\x1b[K");
236 }
237 let _ = std::io::stderr().flush();
238 inner.status_lines_displayed = 0;
239 }
240 }
241
242 fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
243 let duration = format_duration(task.duration);
244 let name_width = inner.max_example_name_len;
245 let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
246
247 if inner.is_tty {
248 let reset = "\x1b[0m";
249 let bold = "\x1b[1m";
250 let dim = "\x1b[2m";
251
252 let yellow = "\x1b[33m";
253 let info_part = task
254 .info
255 .as_ref()
256 .map(|(s, style)| {
257 if *style == InfoStyle::Warning {
258 format!("{yellow}{s}{reset}")
259 } else {
260 s.to_string()
261 }
262 })
263 .unwrap_or_default();
264
265 let prefix = format!(
266 "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
267 color = task.step.color_code(),
268 label = task.step.label(),
269 name = truncated_name,
270 );
271
272 let duration_with_margin = format!("{duration} ");
273 let padding_needed = inner
274 .terminal_width
275 .saturating_sub(MARGIN)
276 .saturating_sub(duration_with_margin.len())
277 .saturating_sub(strip_ansi_len(&prefix));
278 let padding = " ".repeat(padding_needed);
279
280 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
281 } else {
282 let info_part = task
283 .info
284 .as_ref()
285 .map(|(s, _)| format!(" | {}", s))
286 .unwrap_or_default();
287
288 eprintln!(
289 "{label:>12} {name:<name_width$}{info_part} {duration}",
290 label = task.step.label(),
291 name = truncate_with_ellipsis(&task.example_name, name_width),
292 );
293 }
294 }
295
296 fn print_status_lines(inner: &mut ProgressInner) {
297 if !inner.is_tty || inner.in_progress.is_empty() {
298 inner.status_lines_displayed = 0;
299 return;
300 }
301
302 let reset = "\x1b[0m";
303 let bold = "\x1b[1m";
304 let dim = "\x1b[2m";
305
306 // Build the done/in-progress/total label
307 let done_count = inner.completed.len();
308 let in_progress_count = inner.in_progress.len();
309 let failed_count = inner.failed_examples;
310
311 let failed_label = if failed_count > 0 {
312 format!(" {} failed ", failed_count)
313 } else {
314 String::new()
315 };
316
317 let range_label = format!(
318 " {}/{}/{} ",
319 done_count, in_progress_count, inner.total_steps
320 );
321
322 // Print a divider line with failed count on left, range label on right
323 let failed_visible_len = strip_ansi_len(&failed_label);
324 let range_visible_len = range_label.len();
325 let middle_divider_len = inner
326 .terminal_width
327 .saturating_sub(MARGIN * 2)
328 .saturating_sub(failed_visible_len)
329 .saturating_sub(range_visible_len);
330 let left_divider = "─".repeat(MARGIN);
331 let middle_divider = "─".repeat(middle_divider_len);
332 let right_divider = "─".repeat(MARGIN);
333 eprintln!(
334 "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
335 );
336
337 let mut tasks: Vec<_> = inner.in_progress.iter().collect();
338 tasks.sort_by_key(|(name, _)| *name);
339
340 let total_tasks = tasks.len();
341 let mut lines_printed = 0;
342
343 for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
344 let elapsed = format_duration(task.started_at.elapsed());
345 let substatus_part = task
346 .substatus
347 .as_ref()
348 .map(|s| truncate_with_ellipsis(s, 30))
349 .unwrap_or_default();
350
351 let step_label = task.step.label();
352 let step_color = task.step.color_code();
353 let name_width = inner.max_example_name_len;
354 let truncated_name = truncate_with_ellipsis(name, name_width);
355
356 let prefix = format!(
357 "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
358 name = truncated_name,
359 );
360
361 let duration_with_margin = format!("{elapsed} ");
362 let padding_needed = inner
363 .terminal_width
364 .saturating_sub(MARGIN)
365 .saturating_sub(duration_with_margin.len())
366 .saturating_sub(strip_ansi_len(&prefix));
367 let padding = " ".repeat(padding_needed);
368
369 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
370 lines_printed += 1;
371 }
372
373 // Show "+N more" on its own line if there are more tasks
374 if total_tasks > MAX_STATUS_LINES {
375 let remaining = total_tasks - MAX_STATUS_LINES;
376 eprintln!("{:>12} +{remaining} more", "");
377 lines_printed += 1;
378 }
379
380 inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
381 let _ = std::io::stderr().flush();
382 }
383
384 pub fn finalize(&self) {
385 let ticker = {
386 let mut inner = self.inner.lock().unwrap();
387 inner.ticker.take()
388 };
389
390 if let Some(ticker) = ticker {
391 let _ = ticker.join();
392 }
393
394 let mut inner = self.inner.lock().unwrap();
395 Self::clear_status_lines(&mut inner);
396
397 // Print summary if there were failures
398 if inner.failed_examples > 0 {
399 let total_processed = inner.completed.len();
400 let percentage = if total_processed > 0 {
401 inner.failed_examples as f64 / total_processed as f64 * 100.0
402 } else {
403 0.0
404 };
405 eprintln!(
406 "\n{} of {} examples failed ({:.1}%)",
407 inner.failed_examples, total_processed, percentage
408 );
409 }
410 }
411}
412
413pub struct StepProgress {
414 progress: Arc<Progress>,
415 step: Step,
416 example_name: String,
417}
418
419impl StepProgress {
420 pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
421 let mut inner = self.progress.inner.lock().unwrap();
422 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
423 task.substatus = Some(substatus.into().into_owned());
424 Progress::clear_status_lines(&mut inner);
425 Progress::print_status_lines(&mut inner);
426 }
427 }
428
429 pub fn clear_substatus(&self) {
430 let mut inner = self.progress.inner.lock().unwrap();
431 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
432 task.substatus = None;
433 Progress::clear_status_lines(&mut inner);
434 Progress::print_status_lines(&mut inner);
435 }
436 }
437
438 pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
439 let mut inner = self.progress.inner.lock().unwrap();
440 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
441 task.info = Some((info.into(), style));
442 }
443 }
444}
445
446impl Drop for StepProgress {
447 fn drop(&mut self) {
448 self.progress.finish(self.step, &self.example_name);
449 }
450}
451
/// `log` backend that prefixes each record with a right-aligned level label and routes
/// it through the global `Progress` (when initialized) so it does not corrupt the
/// status lines.
struct ProgressLogger;
453
454impl Log for ProgressLogger {
455 fn enabled(&self, metadata: &Metadata) -> bool {
456 metadata.level() <= Level::Info
457 }
458
459 fn log(&self, record: &Record) {
460 if !self.enabled(record.metadata()) {
461 return;
462 }
463
464 let level_color = match record.level() {
465 Level::Error => "\x1b[31m",
466 Level::Warn => "\x1b[33m",
467 Level::Info => "\x1b[32m",
468 Level::Debug => "\x1b[34m",
469 Level::Trace => "\x1b[35m",
470 };
471 let reset = "\x1b[0m";
472 let bold = "\x1b[1m";
473
474 let level_label = match record.level() {
475 Level::Error => "Error",
476 Level::Warn => "Warn",
477 Level::Info => "Info",
478 Level::Debug => "Debug",
479 Level::Trace => "Trace",
480 };
481
482 let message = format!(
483 "{bold}{level_color}{level_label:>12}{reset} {}",
484 record.args()
485 );
486
487 if let Some(progress) = GLOBAL.get() {
488 progress.log(&message);
489 } else {
490 eprintln!("{}", message);
491 }
492 }
493
494 fn flush(&self) {
495 let _ = std::io::stderr().flush();
496 }
497}
498
499#[cfg(unix)]
500fn get_terminal_width() -> usize {
501 unsafe {
502 let mut winsize: libc::winsize = std::mem::zeroed();
503 if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
504 && winsize.ws_col > 0
505 {
506 winsize.ws_col as usize
507 } else {
508 80
509 }
510 }
511}
512
#[cfg(not(unix))]
/// Returns the assumed terminal width on platforms without a size query.
fn get_terminal_width() -> usize {
    // No terminal-size query is implemented here; assume the classic 80 columns.
    const DEFAULT_COLUMNS: usize = 80;
    DEFAULT_COLUMNS
}
517
/// Counts the characters of `s` that remain visible once ANSI escapes are removed.
///
/// Everything from an ESC (`\x1b`) through the next `m` is treated as an escape
/// sequence and skipped; every other `char` counts as one.
fn strip_ansi_len(s: &str) -> usize {
    let mut chars = s.chars();
    let mut visible = 0;
    while let Some(c) = chars.next() {
        if c != '\x1b' {
            visible += 1;
            continue;
        }
        // Consume the rest of the escape sequence, up to and including its final 'm'.
        for seq_char in chars.by_ref() {
            if seq_char == 'm' {
                break;
            }
        }
    }
    visible
}
534
/// Shortens `s` to at most `max_len` characters, replacing the tail with `…` when it
/// does not fit.
///
/// Lengths are counted in `char`s and the cut always lands on a character boundary, so
/// multi-byte (non-ASCII) input never panics. For `max_len == 0` a non-empty `s`
/// becomes just `…`.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    // `nth(max_len)` exists only when `s` has more than `max_len` chars.
    if s.char_indices().nth(max_len).is_none() {
        return s.to_string();
    }
    // Keep `max_len - 1` chars so the ellipsis brings the total back to `max_len`.
    let keep = max_len.saturating_sub(1);
    let cut = s.char_indices().nth(keep).map_or(s.len(), |(idx, _)| idx);
    format!("{}…", &s[..cut])
}
542
/// Renders a duration compactly: whole milliseconds below one second, tenths of a
/// second below one minute, and tenths of a minute beyond that.
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * MILLIS_PER_SECOND;

    let total_ms = duration.as_millis() as f32;
    match total_ms {
        ms if ms < MILLIS_PER_SECOND => format!("{ms}ms"),
        ms if ms < MILLIS_PER_MINUTE => format!("{:.1}s", ms / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms / MILLIS_PER_MINUTE),
    }
}