add recall and precision to semantic index

KCaverly created

Change summary

crates/semantic_index/eval/gpt-engineer.json |   2 
crates/semantic_index/examples/eval.rs       | 212 ++++++++++++++++++---
2 files changed, 179 insertions(+), 35 deletions(-)

Detailed changes

crates/semantic_index/eval/gpt-engineer.json 🔗

@@ -1,5 +1,5 @@
 {
-  "repo": "https://github.com/AntonOsika/gpt_engineer.git",
+  "repo": "https://github.com/AntonOsika/gpt-engineer.git",
   "commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
   "assertions": [
     {

crates/semantic_index/examples/eval.rs 🔗

@@ -10,7 +10,7 @@ use rust_embed::RustEmbed;
 use semantic_index::embedding::OpenAIEmbeddings;
 use semantic_index::semantic_index_settings::SemanticIndexSettings;
 use semantic_index::{SearchResult, SemanticIndex};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore};
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
@@ -43,7 +43,7 @@ impl AssetSource for Assets {
     }
 }
 
-#[derive(Deserialize, Clone)]
+#[derive(Deserialize, Clone, Serialize)]
 struct EvaluationQuery {
     query: String,
     matches: Vec<String>,
@@ -72,15 +72,6 @@ struct RepoEval {
     assertions: Vec<EvaluationQuery>,
 }
 
-struct EvaluationResults {
-    token_count: usize,
-    span_count: usize,
-    time_to_index: Duration,
-    time_to_search: Vec<Duration>,
-    ndcg: HashMap<usize, f32>,
-    map: HashMap<usize, f32>,
-}
-
 const TMP_REPO_PATH: &str = "eval_repos";
 
 fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
@@ -114,7 +105,7 @@ fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
     Ok(repo_evals)
 }
 
-fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
+fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<(String, PathBuf)> {
     let repo_name = Path::new(repo_eval.repo.as_str())
         .file_name()
         .unwrap()
@@ -146,7 +137,7 @@ fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
     repo.checkout_tree(&obj, None)?;
     repo.set_head_detached(obj.id())?;
 
-    Ok(clone_path)
+    Ok((repo_name, clone_path))
 }
 
 fn dcg(hits: Vec<usize>) -> f32 {
@@ -253,30 +244,165 @@ fn evaluate_map(hits: Vec<usize>) -> Vec<f32> {
     let mut rolling_map = 0.0;
     for (idx, h) in hits.into_iter().enumerate() {
         rolling_non_zero += h as f32;
-        rolling_map += rolling_non_zero / (idx + 1) as f32;
+        if h == 1 {
+            rolling_map += rolling_non_zero / (idx + 1) as f32;
+        }
         map_at_k.push(rolling_map / non_zero);
     }
 
     map_at_k
 }
 
+fn evaluate_mrr(hits: Vec<usize>) -> f32 {
+    for (idx, h) in hits.into_iter().enumerate() {
+        if h == 1 {
+            return 1.0 / (idx + 1) as f32;
+        }
+    }
+
+    return 0.0;
+}
+
 fn init_logger() {
     env_logger::init();
 }
 
+#[derive(Serialize)]
+struct QueryMetrics {
+    query: EvaluationQuery,
+    millis_to_search: Duration,
+    ndcg: Vec<f32>,
+    map: Vec<f32>,
+    mrr: f32,
+    hits: Vec<usize>,
+    precision: Vec<f32>,
+    recall: Vec<f32>,
+}
+
+#[derive(Serialize)]
+struct SummaryMetrics {
+    millis_to_search: f32,
+    ndcg: Vec<f32>,
+    map: Vec<f32>,
+    mrr: f32,
+    precision: Vec<f32>,
+    recall: Vec<f32>,
+}
+
+#[derive(Serialize)]
+struct RepoEvaluationMetrics {
+    millis_to_index: Duration,
+    query_metrics: Vec<QueryMetrics>,
+    repo_metrics: Option<SummaryMetrics>,
+}
+
+impl RepoEvaluationMetrics {
+    fn new(millis_to_index: Duration) -> Self {
+        RepoEvaluationMetrics {
+            millis_to_index,
+            query_metrics: Vec::new(),
+            repo_metrics: None,
+        }
+    }
+
+    fn save(&self, repo_name: String) -> Result<()> {
+        let results_string = serde_json::to_string(&self)?;
+        fs::write(format!("./{}_evaluation.json", repo_name), results_string)
+            .expect("Unable to write file");
+        Ok(())
+    }
+
+    fn summarize(&mut self) {
+        let l = self.query_metrics.len() as f32;
+        let millis_to_search: f32 = self
+            .query_metrics
+            .iter()
+            .map(|metrics| metrics.millis_to_search.as_millis())
+            .sum::<u128>() as f32
+            / l;
+
+        let mut ndcg_sum = vec![0.0; 10];
+        let mut map_sum = vec![0.0; 10];
+        let mut precision_sum = vec![0.0; 10];
+        let mut recall_sum = vec![0.0; 10];
+        let mut mmr_sum = 0.0;
+
+        for query_metric in self.query_metrics.iter() {
+            for (ndcg, query_ndcg) in ndcg_sum.iter_mut().zip(query_metric.ndcg.clone()) {
+                *ndcg += query_ndcg;
+            }
+
+            for (mapp, query_map) in map_sum.iter_mut().zip(query_metric.map.clone()) {
+                *mapp += query_map;
+            }
+
+            for (pre, query_pre) in precision_sum.iter_mut().zip(query_metric.precision.clone()) {
+                *pre += query_pre;
+            }
+
+            for (rec, query_rec) in recall_sum.iter_mut().zip(query_metric.recall.clone()) {
+                *rec += query_rec;
+            }
+
+            mmr_sum += query_metric.mrr;
+        }
+
+        let ndcg = ndcg_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
+        let map = map_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
+        let precision = precision_sum
+            .iter()
+            .map(|val| val / l)
+            .collect::<Vec<f32>>();
+        let recall = recall_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
+        let mrr = mmr_sum / l;
+
+        self.repo_metrics = Some(SummaryMetrics {
+            millis_to_search,
+            ndcg,
+            map,
+            mrr,
+            precision,
+            recall,
+        })
+    }
+}
+
+fn evaluate_precision(hits: Vec<usize>) -> Vec<f32> {
+    let mut rolling_hit: f32 = 0.0;
+    let mut precision = Vec::new();
+    for (idx, hit) in hits.into_iter().enumerate() {
+        rolling_hit += hit as f32;
+        precision.push(rolling_hit / ((idx as f32) + 1.0));
+    }
+
+    precision
+}
+
+fn evaluate_recall(hits: Vec<usize>, ideal: Vec<usize>) -> Vec<f32> {
+    let total_relevant = ideal.iter().sum::<usize>() as f32;
+    let mut recall = Vec::new();
+    let mut rolling_hit: f32 = 0.0;
+    for hit in hits {
+        rolling_hit += hit as f32;
+        recall.push(rolling_hit / total_relevant);
+    }
+
+    recall
+}
+
 async fn evaluate_repo(
+    repo_name: String,
     index: ModelHandle<SemanticIndex>,
     project: ModelHandle<Project>,
     query_matches: Vec<EvaluationQuery>,
     cx: &mut AsyncAppContext,
-) -> Result<()> {
+) -> Result<RepoEvaluationMetrics> {
     // Index Project
     let index_t0 = Instant::now();
     index
         .update(cx, |index, cx| index.index_project(project.clone(), cx))
         .await?;
-    let index_time = index_t0.elapsed();
-    println!("Time to Index: {:?}", index_time.as_millis());
+    let mut repo_metrics = RepoEvaluationMetrics::new(index_t0.elapsed());
 
     for query in query_matches {
         // Query each match in order
@@ -286,26 +412,45 @@ async fn evaluate_repo(
                 index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx)
             })
             .await?;
-        let search_time = search_t0.elapsed();
-        println!("Time to Search: {:?}", search_time.as_millis());
+        let millis_to_search = search_t0.elapsed();
 
         // Get Hits/Ideal
         let k = 10;
-        let (ideal, hits) = self::get_hits(query, search_results, k, cx);
+        let (ideal, hits) = self::get_hits(query.clone(), search_results, k, cx);
 
         // Evaluate ndcg@k, for k = 1, 3, 5, 10
-        let ndcg = evaluate_ndcg(hits.clone(), ideal);
-        println!("NDCG: {:?}", ndcg);
+        let ndcg = evaluate_ndcg(hits.clone(), ideal.clone());
 
         // Evaluate map@k, for k = 1, 3, 5, 10
-        let map = evaluate_map(hits);
-        println!("MAP: {:?}", map);
+        let map = evaluate_map(hits.clone());
+
+        // Evaluate mrr
+        let mrr = evaluate_mrr(hits.clone());
+
+        // Evaluate precision
+        let precision = evaluate_precision(hits.clone());
 
-        // Evaluate span count
-        // Evaluate token count
+        // Evaluate Recall
+        let recall = evaluate_recall(hits.clone(), ideal);
+
+        let query_metrics = QueryMetrics {
+            query,
+            millis_to_search,
+            ndcg,
+            map,
+            mrr,
+            hits,
+            precision,
+            recall,
+        };
+
+        repo_metrics.query_metrics.push(query_metrics);
     }
 
-    anyhow::Ok(())
+    repo_metrics.summarize();
+    repo_metrics.save(repo_name);
+
+    anyhow::Ok(repo_metrics)
 }
 
 fn main() {
@@ -367,12 +512,10 @@ fn main() {
                 for repo in repo_evals {
                     let cloned = clone_repo(repo.clone());
                     match cloned {
-                        Ok(clone_path) => {
-                            log::trace!(
+                        Ok((repo_name, clone_path)) => {
+                            println!(
                                 "Cloned {:?} @ {:?} into {:?}",
-                                repo.repo,
-                                repo.commit,
-                                &clone_path
+                                repo.repo, repo.commit, &clone_path
                             );
 
                             // Create Project
@@ -393,7 +536,8 @@ fn main() {
                                 })
                                 .await;
 
-                            evaluate_repo(
+                            let repo_metrics = evaluate_repo(
+                                repo_name,
                                 semantic_index.clone(),
                                 project,
                                 repo.assertions,
@@ -402,7 +546,7 @@ fn main() {
                             .await?;
                         }
                         Err(err) => {
-                            log::trace!("Error cloning: {:?}", err);
+                            println!("Error cloning: {:?}", err);
                         }
                     }
                 }