eval_utils.rs

  1//! Utilities for evaluation and benchmarking.
  2
  3use std::{
  4    collections::HashMap,
  5    sync::{Arc, mpsc},
  6};
  7
  8fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
  9    let passed_count = evaluated_count - failed_count;
 10    let passed_ratio = if evaluated_count == 0 {
 11        0.0
 12    } else {
 13        passed_count as f64 / evaluated_count as f64
 14    };
 15    println!(
 16        "\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
 17        evaluated_count,
 18        iterations,
 19        passed_ratio * 100.0
 20    )
 21}
 22
 23#[derive(Clone, Debug, Eq, PartialEq)]
 24pub enum OutcomeKind {
 25    Passed,
 26    Failed,
 27    Error,
 28}
 29
 30pub trait EvalOutputProcessor {
 31    type Metadata: 'static + Send;
 32    fn process(&mut self, output: &EvalOutput<Self::Metadata>);
 33    fn assert(&mut self);
 34}
 35
 36#[derive(Clone, Debug)]
 37pub struct EvalOutput<M> {
 38    pub outcome: OutcomeKind,
 39    pub data: String,
 40    pub metadata: M,
 41}
 42
 43pub struct NoProcessor;
 44impl EvalOutputProcessor for NoProcessor {
 45    type Metadata = ();
 46
 47    fn process(&mut self, _output: &EvalOutput<Self::Metadata>) {}
 48
 49    fn assert(&mut self) {}
 50}
 51
 52pub fn eval<P>(
 53    iterations: usize,
 54    expected_pass_ratio: f32,
 55    mut processor: P,
 56    evalf: impl Fn() -> EvalOutput<P::Metadata> + Send + Sync + 'static,
 57) where
 58    P: EvalOutputProcessor,
 59{
 60    let mut evaluated_count = 0;
 61    let mut failed_count = 0;
 62    let evalf = Arc::new(evalf);
 63    report_progress(evaluated_count, failed_count, iterations);
 64
 65    let (tx, rx) = mpsc::channel();
 66
 67    let executor = gpui::background_executor();
 68    let semaphore = Arc::new(smol::lock::Semaphore::new(32));
 69    let evalf = Arc::new(evalf);
 70    // Warm the cache once
 71    let first_output = evalf();
 72    tx.send(first_output).ok();
 73
 74    for _ in 1..iterations {
 75        let tx = tx.clone();
 76        let semaphore = semaphore.clone();
 77        let evalf = evalf.clone();
 78        executor
 79            .spawn(async move {
 80                let _guard = semaphore.acquire().await;
 81                let output = evalf();
 82                tx.send(output).ok();
 83            })
 84            .detach();
 85    }
 86    drop(tx);
 87
 88    let mut failed_evals = Vec::new();
 89    let mut errored_evals = HashMap::new();
 90    while let Ok(output) = rx.recv() {
 91        processor.process(&output);
 92
 93        match output.outcome {
 94            OutcomeKind::Passed => {}
 95            OutcomeKind::Failed => {
 96                failed_count += 1;
 97                failed_evals.push(output);
 98            }
 99            OutcomeKind::Error => {
100                failed_count += 1;
101                *errored_evals.entry(output.data).or_insert(0) += 1;
102            }
103        }
104
105        evaluated_count += 1;
106        report_progress(evaluated_count, failed_count, iterations);
107    }
108
109    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
110    println!("Actual pass ratio: {}\n", actual_pass_ratio);
111    if actual_pass_ratio < expected_pass_ratio {
112        for (error, count) in errored_evals {
113            println!("Eval errored {} times. Error: {}", count, error);
114        }
115
116        for failed in failed_evals {
117            println!("Eval failed");
118            println!("{}", failed.data);
119        }
120
121        panic!(
122            "Actual pass ratio: {}\nExpected pass ratio: {}",
123            actual_pass_ratio, expected_pass_ratio
124        );
125    }
126
127    processor.assert();
128}