eval_utils.rs

  1//! Utilities for evaluation and benchmarking.
  2
  3use std::{
  4    collections::HashMap,
  5    sync::{Arc, mpsc},
  6};
  7
  8fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
  9    let passed_count = evaluated_count - failed_count;
 10    let passed_ratio = if evaluated_count == 0 {
 11        0.0
 12    } else {
 13        passed_count as f64 / evaluated_count as f64
 14    };
 15    println!(
 16        "\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
 17        evaluated_count,
 18        iterations,
 19        passed_ratio * 100.0
 20    )
 21}
 22
 23#[derive(Clone, Debug, Eq, PartialEq)]
 24pub enum OutcomeKind {
 25    Passed,
 26    Failed,
 27    Error,
 28}
 29
 30pub trait EvalOutputProcessor {
 31    type Metadata: 'static + Send;
 32    fn process(&mut self, output: &EvalOutput<Self::Metadata>);
 33    fn assert(&mut self);
 34}
 35
 36#[derive(Clone, Debug)]
 37pub struct EvalOutput<M> {
 38    pub outcome: OutcomeKind,
 39    pub data: String,
 40    pub metadata: M,
 41}
 42
 43impl<M: Default> EvalOutput<M> {
 44    pub fn passed(message: impl Into<String>) -> Self {
 45        EvalOutput {
 46            outcome: OutcomeKind::Passed,
 47            data: message.into(),
 48            metadata: M::default(),
 49        }
 50    }
 51
 52    pub fn failed(message: impl Into<String>) -> Self {
 53        EvalOutput {
 54            outcome: OutcomeKind::Failed,
 55            data: message.into(),
 56            metadata: M::default(),
 57        }
 58    }
 59}
 60
 61pub struct NoProcessor;
 62impl EvalOutputProcessor for NoProcessor {
 63    type Metadata = ();
 64
 65    fn process(&mut self, _output: &EvalOutput<Self::Metadata>) {}
 66
 67    fn assert(&mut self) {}
 68}
 69
 70pub fn eval<P>(
 71    iterations: usize,
 72    expected_pass_ratio: f32,
 73    mut processor: P,
 74    evalf: impl Fn() -> EvalOutput<P::Metadata> + Send + Sync + 'static,
 75) where
 76    P: EvalOutputProcessor,
 77{
 78    let mut evaluated_count = 0;
 79    let mut failed_count = 0;
 80    let evalf = Arc::new(evalf);
 81    report_progress(evaluated_count, failed_count, iterations);
 82
 83    let (tx, rx) = mpsc::channel();
 84
 85    let executor = gpui::background_executor();
 86    let semaphore = Arc::new(smol::lock::Semaphore::new(32));
 87    let evalf = Arc::new(evalf);
 88    // Warm the cache once
 89    let first_output = evalf();
 90    tx.send(first_output).ok();
 91
 92    for _ in 1..iterations {
 93        let tx = tx.clone();
 94        let semaphore = semaphore.clone();
 95        let evalf = evalf.clone();
 96        executor
 97            .spawn(async move {
 98                let _guard = semaphore.acquire().await;
 99                let output = evalf();
100                tx.send(output).ok();
101            })
102            .detach();
103    }
104    drop(tx);
105
106    let mut failed_evals = Vec::new();
107    let mut errored_evals = HashMap::new();
108    while let Ok(output) = rx.recv() {
109        processor.process(&output);
110
111        match output.outcome {
112            OutcomeKind::Passed => {}
113            OutcomeKind::Failed => {
114                failed_count += 1;
115                failed_evals.push(output);
116            }
117            OutcomeKind::Error => {
118                failed_count += 1;
119                *errored_evals.entry(output.data).or_insert(0) += 1;
120            }
121        }
122
123        evaluated_count += 1;
124        report_progress(evaluated_count, failed_count, iterations);
125    }
126
127    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
128    println!("Actual pass ratio: {}\n", actual_pass_ratio);
129    if actual_pass_ratio < expected_pass_ratio {
130        for (error, count) in errored_evals {
131            println!("Eval errored {} times. Error: {}", count, error);
132        }
133
134        for failed in failed_evals {
135            println!("Eval failed");
136            println!("{}", failed.data);
137        }
138
139        panic!(
140            "Actual pass ratio: {}\nExpected pass ratio: {}",
141            actual_pass_ratio, expected_pass_ratio
142        );
143    }
144
145    processor.assert();
146}