1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 io::{IsTerminal, Write},
5 sync::{Arc, Mutex, OnceLock},
6 time::{Duration, Instant},
7};
8
9use log::{Level, Log, Metadata, Record};
10
11pub struct Progress {
12 inner: Mutex<ProgressInner>,
13}
14
15struct ProgressInner {
16 completed: Vec<CompletedTask>,
17 in_progress: HashMap<String, InProgressTask>,
18 is_tty: bool,
19 terminal_width: usize,
20 max_example_name_len: usize,
21 status_lines_displayed: usize,
22 total_examples: usize,
23 last_line_is_logging: bool,
24}
25
26#[derive(Clone)]
27struct InProgressTask {
28 step: Step,
29 started_at: Instant,
30 substatus: Option<String>,
31 info: Option<(String, InfoStyle)>,
32}
33
34struct CompletedTask {
35 step: Step,
36 example_name: String,
37 duration: Duration,
38 info: Option<(String, InfoStyle)>,
39}
40
/// A stage of the per-example pipeline reported through [`Progress`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
    LoadProject,
    Context,
    FormatPrompt,
    Predict,
    Score,
}

/// How a task's info text is rendered once the task completes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InfoStyle {
    /// Printed without extra styling.
    Normal,
    /// Highlighted in yellow when stderr is a terminal.
    Warning,
}

impl Step {
    /// Short human-readable name shown in the right-aligned label column.
    pub fn label(&self) -> &'static str {
        let (label, _) = self.appearance();
        label
    }

    /// ANSI SGR escape sequence selecting this step's foreground color.
    fn color_code(&self) -> &'static str {
        let (_, color) = self.appearance();
        color
    }

    /// Label and color escape for this step, kept in one table so the two cannot drift apart.
    fn appearance(&self) -> (&'static str, &'static str) {
        match self {
            Step::LoadProject => ("Load", "\x1b[33m"),
            Step::Context => ("Context", "\x1b[35m"),
            Step::FormatPrompt => ("Format", "\x1b[34m"),
            Step::Predict => ("Predict", "\x1b[32m"),
            Step::Score => ("Score", "\x1b[31m"),
        }
    }
}
77
78static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
79static LOGGER: ProgressLogger = ProgressLogger;
80
81const RIGHT_MARGIN: usize = 4;
82const MAX_STATUS_LINES: usize = 10;
83
84impl Progress {
85 /// Returns the global Progress instance, initializing it if necessary.
86 pub fn global() -> Arc<Progress> {
87 GLOBAL
88 .get_or_init(|| {
89 let progress = Arc::new(Self {
90 inner: Mutex::new(ProgressInner {
91 completed: Vec::new(),
92 in_progress: HashMap::new(),
93 is_tty: std::io::stderr().is_terminal(),
94 terminal_width: get_terminal_width(),
95 max_example_name_len: 0,
96 status_lines_displayed: 0,
97 total_examples: 0,
98 last_line_is_logging: false,
99 }),
100 });
101 let _ = log::set_logger(&LOGGER);
102 log::set_max_level(log::LevelFilter::Error);
103 progress
104 })
105 .clone()
106 }
107
108 pub fn set_total_examples(&self, total: usize) {
109 let mut inner = self.inner.lock().unwrap();
110 inner.total_examples = total;
111 }
112
113 /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
114 /// This should be used for any output that needs to appear above the status lines.
115 fn log(&self, message: &str) {
116 let mut inner = self.inner.lock().unwrap();
117 Self::clear_status_lines(&mut inner);
118
119 if !inner.last_line_is_logging {
120 let reset = "\x1b[0m";
121 let dim = "\x1b[2m";
122 let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
123 eprintln!("{dim}{divider}{reset}");
124 inner.last_line_is_logging = true;
125 }
126
127 eprintln!("{}", message);
128 }
129
130 pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
131 let mut inner = self.inner.lock().unwrap();
132
133 Self::clear_status_lines(&mut inner);
134
135 inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
136 inner.in_progress.insert(
137 example_name.to_string(),
138 InProgressTask {
139 step,
140 started_at: Instant::now(),
141 substatus: None,
142 info: None,
143 },
144 );
145
146 Self::print_status_lines(&mut inner);
147
148 StepProgress {
149 progress: self.clone(),
150 step,
151 example_name: example_name.to_string(),
152 }
153 }
154
155 fn finish(&self, step: Step, example_name: &str) {
156 let mut inner = self.inner.lock().unwrap();
157
158 let Some(task) = inner.in_progress.remove(example_name) else {
159 return;
160 };
161
162 if task.step == step {
163 inner.completed.push(CompletedTask {
164 step: task.step,
165 example_name: example_name.to_string(),
166 duration: task.started_at.elapsed(),
167 info: task.info,
168 });
169
170 Self::clear_status_lines(&mut inner);
171 Self::print_logging_closing_divider(&mut inner);
172 Self::print_completed(&inner, inner.completed.last().unwrap());
173 Self::print_status_lines(&mut inner);
174 } else {
175 inner.in_progress.insert(example_name.to_string(), task);
176 }
177 }
178
179 fn print_logging_closing_divider(inner: &mut ProgressInner) {
180 if inner.last_line_is_logging {
181 let reset = "\x1b[0m";
182 let dim = "\x1b[2m";
183 let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
184 eprintln!("{dim}{divider}{reset}");
185 inner.last_line_is_logging = false;
186 }
187 }
188
189 fn clear_status_lines(inner: &mut ProgressInner) {
190 if inner.is_tty && inner.status_lines_displayed > 0 {
191 // Move up and clear each line we previously displayed
192 for _ in 0..inner.status_lines_displayed {
193 eprint!("\x1b[A\x1b[K");
194 }
195 let _ = std::io::stderr().flush();
196 inner.status_lines_displayed = 0;
197 }
198 }
199
200 fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
201 let duration = format_duration(task.duration);
202 let name_width = inner.max_example_name_len;
203
204 if inner.is_tty {
205 let reset = "\x1b[0m";
206 let bold = "\x1b[1m";
207 let dim = "\x1b[2m";
208
209 let yellow = "\x1b[33m";
210 let info_part = task
211 .info
212 .as_ref()
213 .map(|(s, style)| {
214 if *style == InfoStyle::Warning {
215 format!("{yellow}{s}{reset}")
216 } else {
217 s.to_string()
218 }
219 })
220 .unwrap_or_default();
221
222 let prefix = format!(
223 "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
224 color = task.step.color_code(),
225 label = task.step.label(),
226 name = task.example_name,
227 );
228
229 let duration_with_margin = format!("{duration} ");
230 let padding_needed = inner
231 .terminal_width
232 .saturating_sub(RIGHT_MARGIN)
233 .saturating_sub(duration_with_margin.len())
234 .saturating_sub(strip_ansi_len(&prefix));
235 let padding = " ".repeat(padding_needed);
236
237 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
238 } else {
239 let info_part = task
240 .info
241 .as_ref()
242 .map(|(s, _)| format!(" | {}", s))
243 .unwrap_or_default();
244
245 eprintln!(
246 "{label:>12} {name:<name_width$}{info_part} {duration}",
247 label = task.step.label(),
248 name = task.example_name,
249 );
250 }
251 }
252
253 fn print_status_lines(inner: &mut ProgressInner) {
254 if !inner.is_tty || inner.in_progress.is_empty() {
255 inner.status_lines_displayed = 0;
256 return;
257 }
258
259 let reset = "\x1b[0m";
260 let bold = "\x1b[1m";
261 let dim = "\x1b[2m";
262
263 // Build the done/in-progress/total label
264 let done_count = inner.completed.len();
265 let in_progress_count = inner.in_progress.len();
266 let range_label = format!(
267 " {}/{}/{} ",
268 done_count, in_progress_count, inner.total_examples
269 );
270
271 // Print a divider line with range label aligned with timestamps
272 let range_visible_len = range_label.len();
273 let left_divider_len = inner
274 .terminal_width
275 .saturating_sub(RIGHT_MARGIN)
276 .saturating_sub(range_visible_len);
277 let left_divider = "─".repeat(left_divider_len);
278 let right_divider = "─".repeat(RIGHT_MARGIN);
279 eprintln!("{dim}{left_divider}{reset}{range_label}{dim}{right_divider}{reset}");
280
281 let mut tasks: Vec<_> = inner.in_progress.iter().collect();
282 tasks.sort_by_key(|(name, _)| *name);
283
284 let total_tasks = tasks.len();
285 let mut lines_printed = 0;
286
287 for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
288 let elapsed = format_duration(task.started_at.elapsed());
289 let substatus_part = task
290 .substatus
291 .as_ref()
292 .map(|s| truncate_with_ellipsis(s, 30))
293 .unwrap_or_default();
294
295 let step_label = task.step.label();
296 let step_color = task.step.color_code();
297 let name_width = inner.max_example_name_len;
298
299 let prefix = format!(
300 "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
301 name = name,
302 );
303
304 let duration_with_margin = format!("{elapsed} ");
305 let padding_needed = inner
306 .terminal_width
307 .saturating_sub(RIGHT_MARGIN)
308 .saturating_sub(duration_with_margin.len())
309 .saturating_sub(strip_ansi_len(&prefix));
310 let padding = " ".repeat(padding_needed);
311
312 eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
313 lines_printed += 1;
314 }
315
316 // Show "+N more" on its own line if there are more tasks
317 if total_tasks > MAX_STATUS_LINES {
318 let remaining = total_tasks - MAX_STATUS_LINES;
319 eprintln!("{:>12} +{remaining} more", "");
320 lines_printed += 1;
321 }
322
323 inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
324 let _ = std::io::stderr().flush();
325 }
326
327 pub fn clear(&self) {
328 let mut inner = self.inner.lock().unwrap();
329 Self::clear_status_lines(&mut inner);
330 }
331}
332
333pub struct StepProgress {
334 progress: Arc<Progress>,
335 step: Step,
336 example_name: String,
337}
338
339impl StepProgress {
340 pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
341 let mut inner = self.progress.inner.lock().unwrap();
342 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
343 task.substatus = Some(substatus.into().into_owned());
344 Progress::clear_status_lines(&mut inner);
345 Progress::print_status_lines(&mut inner);
346 }
347 }
348
349 pub fn clear_substatus(&self) {
350 let mut inner = self.progress.inner.lock().unwrap();
351 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
352 task.substatus = None;
353 Progress::clear_status_lines(&mut inner);
354 Progress::print_status_lines(&mut inner);
355 }
356 }
357
358 pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
359 let mut inner = self.progress.inner.lock().unwrap();
360 if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
361 task.info = Some((info.into(), style));
362 }
363 }
364}
365
366impl Drop for StepProgress {
367 fn drop(&mut self) {
368 self.progress.finish(self.step, &self.example_name);
369 }
370}
371
372struct ProgressLogger;
373
374impl Log for ProgressLogger {
375 fn enabled(&self, metadata: &Metadata) -> bool {
376 metadata.level() <= Level::Info
377 }
378
379 fn log(&self, record: &Record) {
380 if !self.enabled(record.metadata()) {
381 return;
382 }
383
384 let level_color = match record.level() {
385 Level::Error => "\x1b[31m",
386 Level::Warn => "\x1b[33m",
387 Level::Info => "\x1b[32m",
388 Level::Debug => "\x1b[34m",
389 Level::Trace => "\x1b[35m",
390 };
391 let reset = "\x1b[0m";
392 let bold = "\x1b[1m";
393
394 let level_label = match record.level() {
395 Level::Error => "Error",
396 Level::Warn => "Warn",
397 Level::Info => "Info",
398 Level::Debug => "Debug",
399 Level::Trace => "Trace",
400 };
401
402 let message = format!(
403 "{bold}{level_color}{level_label:>12}{reset} {}",
404 record.args()
405 );
406
407 if let Some(progress) = GLOBAL.get() {
408 progress.log(&message);
409 } else {
410 eprintln!("{}", message);
411 }
412 }
413
414 fn flush(&self) {
415 let _ = std::io::stderr().flush();
416 }
417}
418
419#[cfg(unix)]
420fn get_terminal_width() -> usize {
421 unsafe {
422 let mut winsize: libc::winsize = std::mem::zeroed();
423 if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
424 && winsize.ws_col > 0
425 {
426 winsize.ws_col as usize
427 } else {
428 80
429 }
430 }
431}
432
/// Returns the assumed terminal width on platforms where it is not queried: 80 columns.
#[cfg(not(unix))]
fn get_terminal_width() -> usize {
    const FALLBACK_WIDTH: usize = 80;
    FALLBACK_WIDTH
}
437
/// Counts the visible chars of `s`, skipping ANSI escape sequences.
///
/// An escape starts at `ESC` (`\x1b`) and runs through the next `m`, which covers the SGR color
/// and style codes used in this file.
fn strip_ansi_len(s: &str) -> usize {
    let mut inside_escape = false;
    s.chars()
        .filter(|&ch| {
            if ch == '\x1b' {
                inside_escape = true;
                return false;
            }
            if inside_escape {
                if ch == 'm' {
                    inside_escape = false;
                }
                return false;
            }
            true
        })
        .count()
}
454
/// Shortens `s` to at most `max_len` chars, replacing the tail with `…` when it is cut.
///
/// Lengths are counted in Unicode scalar values (chars), not bytes, so multi-byte text is never
/// sliced through the middle of a character. The previous byte-based slicing could panic on
/// such input. A string of `max_len` chars or fewer is returned unchanged. Otherwise the first
/// `max_len - 1` chars are kept and `…` is appended, so the result is `max_len` chars long
/// (`"…"` when `max_len` is 0).
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    // `nth(max_len)` exists only if `s` has more than `max_len` chars.
    if s.char_indices().nth(max_len).is_none() {
        return s.to_string();
    }

    let keep = max_len.saturating_sub(1);
    let cut = s
        .char_indices()
        .nth(keep)
        .map_or(s.len(), |(byte_index, _)| byte_index);
    format!("{}…", &s[..cut])
}
462
/// Formats a duration compactly: whole milliseconds below one second, then seconds with one
/// decimal below one minute, then minutes with one decimal.
fn format_duration(duration: Duration) -> String {
    const MINUTE_IN_MILLIS: f32 = 60. * 1000.;

    match duration.as_millis() as f32 {
        ms if ms < 1000.0 => format!("{ms}ms"),
        ms if ms < MINUTE_IN_MILLIS => format!("{:.1}s", ms / 1_000.0),
        ms => format!("{:.1}m", ms / MINUTE_IN_MILLIS),
    }
}