1//! Utilities for evaluation and benchmarking.
2
3use std::{
4 collections::HashMap,
5 sync::{Arc, mpsc},
6};
7
/// Prints a one-line progress summary for an evaluation run.
///
/// The summary shows how many of the `iterations` evaluations have completed so far and
/// what percentage of the completed ones passed. The line starts with `\r` followed by the
/// ANSI "erase to end of line" sequence (`\x1b[K`). `failed_count` is expected not to exceed
/// `evaluated_count`.
fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
    let passed = evaluated_count - failed_count;
    let percent_passed = match evaluated_count {
        // Nothing has completed yet, so report 0% instead of dividing by zero.
        0 => 0.0,
        total => passed as f64 / total as f64 * 100.0,
    };
    println!(
        "\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
        evaluated_count, iterations, percent_passed
    )
}
22
/// The category of a single evaluation result, as reported in [`EvalOutput::outcome`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutcomeKind {
    /// Counted as a pass by `eval`.
    Passed,
    /// Counted as a failure by `eval`; the output's `data` is printed if the pass ratio is
    /// too low.
    Failed,
    /// Counted as a failure by `eval`; the output's `data` is treated as an error message
    /// and identical messages are tallied together.
    Error,
}
29
30pub trait EvalOutputProcessor {
31 type Metadata: 'static + Send;
32 fn process(&mut self, output: &EvalOutput<Self::Metadata>);
33 fn assert(&mut self);
34}
35
36#[derive(Clone, Debug)]
37pub struct EvalOutput<M> {
38 pub outcome: OutcomeKind,
39 pub data: String,
40 pub metadata: M,
41}
42
43pub struct NoProcessor;
44impl EvalOutputProcessor for NoProcessor {
45 type Metadata = ();
46
47 fn process(&mut self, _output: &EvalOutput<Self::Metadata>) {}
48
49 fn assert(&mut self) {}
50}
51
52pub fn eval<P>(
53 iterations: usize,
54 expected_pass_ratio: f32,
55 mut processor: P,
56 evalf: impl Fn() -> EvalOutput<P::Metadata> + Send + Sync + 'static,
57) where
58 P: EvalOutputProcessor,
59{
60 let mut evaluated_count = 0;
61 let mut failed_count = 0;
62 let evalf = Arc::new(evalf);
63 report_progress(evaluated_count, failed_count, iterations);
64
65 let (tx, rx) = mpsc::channel();
66
67 let executor = gpui::background_executor();
68 let semaphore = Arc::new(smol::lock::Semaphore::new(32));
69 let evalf = Arc::new(evalf);
70 // Warm the cache once
71 let first_output = evalf();
72 tx.send(first_output).ok();
73
74 for _ in 1..iterations {
75 let tx = tx.clone();
76 let semaphore = semaphore.clone();
77 let evalf = evalf.clone();
78 executor
79 .spawn(async move {
80 let _guard = semaphore.acquire().await;
81 let output = evalf();
82 tx.send(output).ok();
83 })
84 .detach();
85 }
86 drop(tx);
87
88 let mut failed_evals = Vec::new();
89 let mut errored_evals = HashMap::new();
90 while let Ok(output) = rx.recv() {
91 processor.process(&output);
92
93 match output.outcome {
94 OutcomeKind::Passed => {}
95 OutcomeKind::Failed => {
96 failed_count += 1;
97 failed_evals.push(output);
98 }
99 OutcomeKind::Error => {
100 failed_count += 1;
101 *errored_evals.entry(output.data).or_insert(0) += 1;
102 }
103 }
104
105 evaluated_count += 1;
106 report_progress(evaluated_count, failed_count, iterations);
107 }
108
109 let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
110 println!("Actual pass ratio: {}\n", actual_pass_ratio);
111 if actual_pass_ratio < expected_pass_ratio {
112 for (error, count) in errored_evals {
113 println!("Eval errored {} times. Error: {}", count, error);
114 }
115
116 for failed in failed_evals {
117 println!("Eval failed");
118 println!("{}", failed.data);
119 }
120
121 panic!(
122 "Actual pass ratio: {}\nExpected pass ratio: {}",
123 actual_pass_ratio, expected_pass_ratio
124 );
125 }
126
127 processor.assert();
128}