@@ -10,7 +10,7 @@ use rust_embed::RustEmbed;
use semantic_index::embedding::OpenAIEmbeddings;
use semantic_index::semantic_index_settings::SemanticIndexSettings;
use semantic_index::{SearchResult, SemanticIndex};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore};
use std::path::{Path, PathBuf};
use std::sync::Arc;
@@ -43,7 +43,7 @@ impl AssetSource for Assets {
}
}
-#[derive(Deserialize, Clone)]
+#[derive(Deserialize, Clone, Serialize)]
struct EvaluationQuery {
query: String,
matches: Vec<String>,
@@ -72,15 +72,6 @@ struct RepoEval {
assertions: Vec<EvaluationQuery>,
}
-struct EvaluationResults {
- token_count: usize,
- span_count: usize,
- time_to_index: Duration,
- time_to_search: Vec<Duration>,
- ndcg: HashMap<usize, f32>,
- map: HashMap<usize, f32>,
-}
-
const TMP_REPO_PATH: &str = "eval_repos";
fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
@@ -114,7 +105,7 @@ fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
Ok(repo_evals)
}
-fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
+fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<(String, PathBuf)> {
let repo_name = Path::new(repo_eval.repo.as_str())
.file_name()
.unwrap()
@@ -146,7 +137,7 @@ fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
repo.checkout_tree(&obj, None)?;
repo.set_head_detached(obj.id())?;
- Ok(clone_path)
+ Ok((repo_name, clone_path))
}
fn dcg(hits: Vec<usize>) -> f32 {
@@ -253,30 +244,165 @@ fn evaluate_map(hits: Vec<usize>) -> Vec<f32> {
let mut rolling_map = 0.0;
for (idx, h) in hits.into_iter().enumerate() {
rolling_non_zero += h as f32;
- rolling_map += rolling_non_zero / (idx + 1) as f32;
+ if h == 1 {
+ rolling_map += rolling_non_zero / (idx + 1) as f32;
+ }
map_at_k.push(rolling_map / non_zero);
}
map_at_k
}
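+// Reciprocal rank: 1 / rank of the first relevant hit (0.0 if none in the top k);
+// averaged across queries in summarize() to give MRR, e.g. hits = [0, 1, 0] -> 0.5.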
+fn evaluate_mrr(hits: Vec<usize>) -> f32 {
+ for (idx, h) in hits.into_iter().enumerate() {
+ if h == 1 {
+ return 1.0 / (idx + 1) as f32;
+ }
+ }
+
+ 0.0
+}
+
fn init_logger() {
env_logger::init();
}
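+// Metrics captured for a single evaluation query.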
+#[derive(Serialize)]
+struct QueryMetrics {
+ query: EvaluationQuery,
+ millis_to_search: Duration,
+ ndcg: Vec<f32>,
+ map: Vec<f32>,
+ mrr: f32,
+ hits: Vec<usize>,
+ precision: Vec<f32>,
+ recall: Vec<f32>,
+}
+
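+// Per-repo averages of the query metrics; each Vec is indexed by k - 1.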
+#[derive(Serialize)]
+struct SummaryMetrics {
+ millis_to_search: f32,
+ ndcg: Vec<f32>,
+ map: Vec<f32>,
+ mrr: f32,
+ precision: Vec<f32>,
+ recall: Vec<f32>,
+}
+
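+// Everything recorded while evaluating one repository, written out as JSON.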
+#[derive(Serialize)]
+struct RepoEvaluationMetrics {
+ millis_to_index: Duration,
+ query_metrics: Vec<QueryMetrics>,
+ repo_metrics: Option<SummaryMetrics>,
+}
+
+impl RepoEvaluationMetrics {
+ fn new(millis_to_index: Duration) -> Self {
+ RepoEvaluationMetrics {
+ millis_to_index,
+ query_metrics: Vec::new(),
+ repo_metrics: None,
+ }
+ }
+
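+ // Serialize all metrics for this repo to `<repo_name>_evaluation.json` in the working directory.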
+ fn save(&self, repo_name: String) -> Result<()> {
+ let results_string = serde_json::to_string(&self)?;
+ fs::write(format!("./{}_evaluation.json", repo_name), results_string)?;
+ Ok(())
+ }
+
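+ // Average each per-query metric across all queries for this repo.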
+ fn summarize(&mut self) {
+ // Avoid NaN averages when no queries were evaluated.
+ if self.query_metrics.is_empty() {
+ return;
+ }
+ let query_count = self.query_metrics.len() as f32;
+ let millis_to_search: f32 = self
+ .query_metrics
+ .iter()
+ .map(|metrics| metrics.millis_to_search.as_millis())
+ .sum::<u128>() as f32
+ / query_count;
+
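+ // Element-wise sums across queries; index i corresponds to k = i + 1.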
+ let mut ndcg_sum = vec![0.0; 10];
+ let mut map_sum = vec![0.0; 10];
+ let mut precision_sum = vec![0.0; 10];
+ let mut recall_sum = vec![0.0; 10];
+ let mut mrr_sum = 0.0;
+
+ for query_metric in self.query_metrics.iter() {
+ for (ndcg, query_ndcg) in ndcg_sum.iter_mut().zip(query_metric.ndcg.clone()) {
+ *ndcg += query_ndcg;
+ }
+
+ for (mapp, query_map) in map_sum.iter_mut().zip(query_metric.map.clone()) {
+ *mapp += query_map;
+ }
+
+ for (pre, query_pre) in precision_sum.iter_mut().zip(query_metric.precision.clone()) {
+ *pre += query_pre;
+ }
+
+ for (rec, query_rec) in recall_sum.iter_mut().zip(query_metric.recall.clone()) {
+ *rec += query_rec;
+ }
+
+ mrr_sum += query_metric.mrr;
+ }
+
+ let ndcg = ndcg_sum.iter().map(|val| val / query_count).collect::<Vec<f32>>();
+ let map = map_sum.iter().map(|val| val / query_count).collect::<Vec<f32>>();
+ let precision = precision_sum
+ .iter()
+ .map(|val| val / query_count)
+ .collect::<Vec<f32>>();
+ let recall = recall_sum
+ .iter()
+ .map(|val| val / query_count)
+ .collect::<Vec<f32>>();
+ let mrr = mrr_sum / query_count;
+
+ self.repo_metrics = Some(SummaryMetrics {
+ millis_to_search,
+ ndcg,
+ map,
+ mrr,
+ precision,
+ recall,
+ })
+ }
+}
+
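+// Precision@k: fraction of the top k results that are relevant,
+// e.g. hits = [1, 0, 1] -> [1.0, 0.5, 0.666...].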
+fn evaluate_precision(hits: Vec<usize>) -> Vec<f32> {
+ let mut rolling_hit: f32 = 0.0;
+ let mut precision = Vec::new();
+ for (idx, hit) in hits.into_iter().enumerate() {
+ rolling_hit += hit as f32;
+ precision.push(rolling_hit / ((idx as f32) + 1.0));
+ }
+
+ precision
+}
+
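+// Recall@k: fraction of all relevant results that appear in the top k,
+// e.g. hits = [1, 0, 1] with two relevant results -> [0.5, 0.5, 1.0].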
+fn evaluate_recall(hits: Vec<usize>, ideal: Vec<usize>) -> Vec<f32> {
+ let total_relevant = ideal.iter().sum::<usize>() as f32;
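+ // Assumes at least one relevant result; an empty ideal set would divide by zero.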
+ let mut recall = Vec::new();
+ let mut rolling_hit: f32 = 0.0;
+ for hit in hits {
+ rolling_hit += hit as f32;
+ recall.push(rolling_hit / total_relevant);
+ }
+
+ recall
+}
+
async fn evaluate_repo(
+ repo_name: String,
index: ModelHandle<SemanticIndex>,
project: ModelHandle<Project>,
query_matches: Vec<EvaluationQuery>,
cx: &mut AsyncAppContext,
-) -> Result<()> {
+) -> Result<RepoEvaluationMetrics> {
// Index Project
let index_t0 = Instant::now();
index
.update(cx, |index, cx| index.index_project(project.clone(), cx))
.await?;
- let index_time = index_t0.elapsed();
- println!("Time to Index: {:?}", index_time.as_millis());
+ let mut repo_metrics = RepoEvaluationMetrics::new(index_t0.elapsed());
for query in query_matches {
// Query each match in order
@@ -286,26 +412,45 @@ async fn evaluate_repo(
index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx)
})
.await?;
- let search_time = search_t0.elapsed();
- println!("Time to Search: {:?}", search_time.as_millis());
+ let millis_to_search = search_t0.elapsed();
// Get Hits/Ideal
let k = 10;
- let (ideal, hits) = self::get_hits(query, search_results, k, cx);
+ let (ideal, hits) = self::get_hits(query.clone(), search_results, k, cx);
// Evaluate ndcg@k, for k = 1, 3, 5, 10
- let ndcg = evaluate_ndcg(hits.clone(), ideal);
- println!("NDCG: {:?}", ndcg);
+ let ndcg = evaluate_ndcg(hits.clone(), ideal.clone());
// Evaluate map@k, for k = 1, 3, 5, 10
- let map = evaluate_map(hits);
- println!("MAP: {:?}", map);
+ let map = evaluate_map(hits.clone());
+
+ // Evaluate mrr
+ let mrr = evaluate_mrr(hits.clone());
+
+ // Evaluate precision
+ let precision = evaluate_precision(hits.clone());
- // Evaluate span count
- // Evaluate token count
+ // Evaluate recall
+ let recall = evaluate_recall(hits.clone(), ideal);
+
+ let query_metrics = QueryMetrics {
+ query,
+ millis_to_search,
+ ndcg,
+ map,
+ mrr,
+ hits,
+ precision,
+ recall,
+ };
+
+ repo_metrics.query_metrics.push(query_metrics);
}
- anyhow::Ok(())
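+ // Average the per-query metrics and persist everything before returning.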
+ repo_metrics.summarize();
+ repo_metrics.save(repo_name)?;
+
+ anyhow::Ok(repo_metrics)
}
fn main() {
@@ -367,12 +512,10 @@ fn main() {
for repo in repo_evals {
let cloned = clone_repo(repo.clone());
match cloned {
- Ok(clone_path) => {
- log::trace!(
+ Ok((repo_name, clone_path)) => {
+ println!(
"Cloned {:?} @ {:?} into {:?}",
- repo.repo,
- repo.commit,
- &clone_path
+ repo.repo, repo.commit, &clone_path
);
// Create Project
@@ -393,7 +536,8 @@ fn main() {
})
.await;
- evaluate_repo(
+ let repo_metrics = evaluate_repo(
+ repo_name,
semantic_index.clone(),
project,
repo.assertions,
@@ -402,7 +546,7 @@ fn main() {
.await?;
}
Err(err) => {
- log::trace!("Error cloning: {:?}", err);
+ println!("Error cloning: {:?}", err);
}
}
}