add ndcg@k to evaluate metrics

Created by KCaverly

Change summary

crates/semantic_index/eval/gpt-engineer.json | 44 ++++++------
crates/semantic_index/examples/eval.rs       | 78 ++++++++++++++++++---
2 files changed, 88 insertions(+), 34 deletions(-)

Detailed changes

crates/semantic_index/eval/gpt-engineer.json

@@ -1,5 +1,5 @@
 {
-  "repo": "https://github.com/AntonOsika/gpt-engineer.git",
+  "repo": "https://github.com/AntonOsika/gpt_engineer.git",
   "commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
   "assertions": [
     {
@@ -12,48 +12,48 @@
     {
       "query": "What version of the openai package is active?",
       "matches": [
-        "pyprojet.toml:14"
+        "pyproject.toml:14"
       ]
     },
     {
       "query": "Ask user for clarification",
       "matches": [
-        "gpt-engineer/steps.py:69"
+        "gpt_engineer/steps.py:69"
       ]
     },
     {
       "query": "generate tests for python code",
       "matches": [
-        "gpt-engineer/steps.py:153"
+        "gpt_engineer/steps.py:153"
       ]
     },
     {
       "query": "get item from database based on key",
       "matches": [
-        "gpt-engineer/db.py:42",
-        "gpt-engineer/db.py:68"
+        "gpt_engineer/db.py:42",
+        "gpt_engineer/db.py:68"
       ]
     },
     {
       "query": "prompt user to select files",
       "matches": [
-        "gpt-engineer/file_selector.py:171",
-        "gpt-engineer/file_selector.py:306",
-        "gpt-engineer/file_selector.py:289",
-        "gpt-engineer/file_selector.py:234"
+        "gpt_engineer/file_selector.py:171",
+        "gpt_engineer/file_selector.py:306",
+        "gpt_engineer/file_selector.py:289",
+        "gpt_engineer/file_selector.py:234"
       ]
     },
     {
       "query": "send to rudderstack",
       "matches": [
-        "gpt-engineer/collect.py:11",
-        "gpt-engineer/collect.py:38"
+        "gpt_engineer/collect.py:11",
+        "gpt_engineer/collect.py:38"
       ]
     },
     {
       "query": "parse code blocks from chat messages",
       "matches": [
-        "gpt-engineer/chat_to_files.py:10",
+        "gpt_engineer/chat_to_files.py:10",
         "docs/intro/chat_parsing.md:1"
       ]
     },
@@ -66,35 +66,35 @@
     {
       "query": "ask the user if the code ran successfully?",
       "matches": [
-        "gpt-engineer/learning.py:54"
+        "gpt_engineer/learning.py:54"
       ]
     },
     {
       "query": "how is consent granted by the user?",
       "matches": [
-        "gpt-engineer/learning.py:107",
-        "gpt-engineer/learning.py:130",
-        "gpt-engineer/learning.py:152"
+        "gpt_engineer/learning.py:107",
+        "gpt_engineer/learning.py:130",
+        "gpt_engineer/learning.py:152"
       ]
     },
     {
       "query": "what are all the different steps the agent can take?",
       "matches": [
         "docs/intro/steps_module.md:1",
-        "gpt-engineer/steps.py:391"
+        "gpt_engineer/steps.py:391"
       ]
     },
     {
       "query": "ask the user for clarification?",
       "matches": [
-        "gpt-engineer/steps.py:69"
+        "gpt_engineer/steps.py:69"
       ]
     },
     {
       "query": "what models are available?",
       "matches": [
-        "gpt-engineer/ai.py:315",
-        "gpt-engineer/ai.py:341",
+        "gpt_engineer/ai.py:315",
+        "gpt_engineer/ai.py:341",
         "docs/open-models.md:1"
       ]
     },
@@ -107,7 +107,7 @@
     {
       "query": "does the agent know how to fix code?",
       "matches": [
-        "gpt-engineer/steps.py:367"
+        "gpt_engineer/steps.py:367"
       ]
     }
   ]

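Each assertion above pairs a natural-language query with the "path:row" locations an ideal search should surface. As a rough sketch of the schema, the assertions could deserialize into the EvaluationQuery struct that eval.rs defines below; the EvaluationSet wrapper and the serde derives here are illustrative assumptions, not taken from the source.

use serde::Deserialize;

#[derive(Deserialize, Clone)]
struct EvaluationSet {
    repo: String,
    commit: String,
    assertions: Vec<EvaluationQuery>,
}

#[derive(Deserialize, Clone)]
struct EvaluationQuery {
    query: String,
    // Each entry is "relative/path.py:row"; see match_pairs in eval.rs below.
    matches: Vec<String>,
}
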
crates/semantic_index/examples/eval.rs

@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use client::{self, UserStore};
 use collections::HashMap;
 use git2::{Object, Oid, Repository};
-use gpui::{AssetSource, AsyncAppContext, ModelHandle, Task};
+use gpui::{AppContext, AssetSource, AsyncAppContext, ModelHandle, Task};
 use language::LanguageRegistry;
 use node_runtime::RealNodeRuntime;
 use project::{Project, RealFs};
@@ -50,17 +50,14 @@ struct EvaluationQuery {
 }
 
 impl EvaluationQuery {
-    fn match_pairs(&self) -> Vec<(PathBuf, usize)> {
+    fn match_pairs(&self) -> Vec<(PathBuf, u32)> {
         let mut pairs = Vec::new();
         for match_identifier in self.matches.iter() {
             let mut match_parts = match_identifier.split(":");
 
             if let Some(file_path) = match_parts.next() {
                 if let Some(row_number) = match_parts.next() {
-                    pairs.push((
-                        PathBuf::from(file_path),
-                        row_number.parse::<usize>().unwrap(),
-                    ));
+                    pairs.push((PathBuf::from(file_path), row_number.parse::<u32>().unwrap()));
                 }
             }
         }
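
The row type switches from usize to u32, matching the u32 rows that buffer.offset_to_point yields in evaluate_ndcg below. A hypothetical round-trip of one fixture entry (the EvaluationQuery construction is assumed for illustration):

use std::path::PathBuf;

let query = EvaluationQuery {
    query: "get item from database based on key".into(),
    matches: vec![
        "gpt_engineer/db.py:42".into(),
        "gpt_engineer/db.py:68".into(),
    ],
};
// match_pairs splits each "path:row" string on ':' and parses the row.
assert_eq!(
    query.match_pairs(),
    vec![
        (PathBuf::from("gpt_engineer/db.py"), 42),
        (PathBuf::from("gpt_engineer/db.py"), 68),
    ]
);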
@@ -156,11 +153,15 @@ fn dcg(hits: Vec<usize>) -> f32 {
         result += *hit as f32 / (2.0 + idx as f32).log2();
     }
 
-    println!("DCG: {:?}", result);
     result
 }
 
-fn evaluate_ndcg(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {
+fn evaluate_ndcg(
+    eval_query: EvaluationQuery,
+    search_results: Vec<SearchResult>,
+    k: usize,
+    cx: &AsyncAppContext,
+) -> Vec<f32> {
     // NDCG, or Normalized Discounted Cumulative Gain, is determined by comparing the relevance of
     // items returned by the search engine relative to the hypothetical ideal.
     // Relevance is represented as a series of booleans, in which each search result returned
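
With the debug println removed, dcg is the plain discounted sum: the hit at zero-based position idx contributes hit / log2(idx + 2), i.e. the usual rel / log2(rank + 1) weighting for 1-based ranks. A quick worked check (illustrative):

// hits = [1, 0, 1]:
//   idx 0: 1 / log2(2) = 1.0
//   idx 1: 0 / log2(3) = 0.0
//   idx 2: 1 / log2(4) = 0.5
// so dcg returns 1.5.
assert!((dcg(vec![1, 0, 1]) - 1.5).abs() < 1e-6);

The signature change is the substance of the PR: evaluate_ndcg now returns a Vec<f32> holding NDCG at every cutoff up to k, and takes the async context so it can read the result buffers.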
@@ -180,9 +181,58 @@ fn evaluate_ndcg(eval_query: EvaluationQuery, search_results: Vec<SearchResult>,
     // very high quality, whereas rank results quickly drop off after the first result.
 
     let ideal = vec![1; cmp::min(eval_query.matches.len(), k)];
-    let hits = vec![1];
 
-    return dcg(hits) / dcg(ideal);
+    let mut hits = Vec::new();
+    for result in search_results {
+        let (path, start_row, end_row) = result.buffer.read_with(cx, |buffer, cx| {
+            let path = buffer.file().unwrap().path().to_path_buf();
+            let start_row = buffer.offset_to_point(result.range.start.offset).row;
+            let end_row = buffer.offset_to_point(result.range.end.offset).row;
+            (path, start_row, end_row)
+        });
+
+        let match_pairs = eval_query.match_pairs();
+        let mut found = 0;
+        for (match_path, match_row) in match_pairs {
+            if match_path == path {
+                if match_row >= start_row && match_row <= end_row {
+                    found = 1;
+                    break;
+                }
+            }
+        }
+
+        hits.push(found);
+    }
+
+    // For now, we are calculating ideal_hits a bit differently, as technically
+    // with overlapping ranges, one match can result in more than one result.
+    let mut ideal_hits = hits.clone();
+    ideal_hits.retain(|x| x == &1);
+
+    let ideal = if ideal.len() > ideal_hits.len() {
+        ideal
+    } else {
+        ideal_hits
+    };
+
+    // Fill ideal out to length 10
+    let mut filled_ideal = [0; 10];
+    for (idx, i) in ideal.to_vec().into_iter().enumerate() {
+        filled_ideal[idx] = i;
+    }
+
+    let mut ndcg = Vec::new();
+    for idx in 1..(hits.len() + 1) {
+        let hits_at_k = hits[0..idx].to_vec();
+        let ideal_at_k = filled_ideal[0..idx].to_vec();
+
+        let at_k = dcg(hits_at_k.clone()) / dcg(ideal_at_k.clone());
+
+        ndcg.push(at_k);
+    }
+
+    ndcg
 }
 
 // fn evaluate_map(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {}
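
Since each iteration divides dcg of hits[0..idx] by dcg of the padded ideal at the same cutoff, the returned vector is NDCG@k for every k up to the number of results. A self-contained sketch of that cumulative computation, stripped of the buffer plumbing (names here are illustrative, and it assumes at least one relevant match so the denominator is non-zero):

fn dcg(hits: &[usize]) -> f32 {
    hits.iter()
        .enumerate()
        .map(|(idx, hit)| *hit as f32 / (2.0 + idx as f32).log2())
        .sum()
}

// NDCG at every cutoff 1..=hits.len(), where hits holds 0/1 relevance
// judgments per search result and total_relevant is the number of
// expected matches for the query.
fn ndcg_at_each_k(hits: &[usize], total_relevant: usize) -> Vec<f32> {
    // Ideal ranking: all relevant results first, padded with zeros.
    let mut ideal = vec![0; hits.len()];
    for slot in ideal.iter_mut().take(total_relevant) {
        *slot = 1;
    }
    (1..=hits.len())
        .map(|k| dcg(&hits[..k]) / dcg(&ideal[..k]))
        .collect()
}

For hits = [1, 0, 1] against two expected matches this yields roughly [1.0, 0.61, 0.92], the shape of the NDCG vector printed per query below.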
@@ -209,14 +259,17 @@ async fn evaluate_repo(
         // Query each match in order
         let search_t0 = Instant::now();
         let search_results = index
-            .update(cx, |index, mut cx| {
-                index.search_project(project.clone(), query.query, 10, vec![], vec![], cx)
+            .update(cx, |index, cx| {
+                index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx)
             })
             .await?;
         let search_time = search_t0.elapsed();
         println!("Time to Search: {:?}", search_time.as_secs());
 
         // Evaluate ndcg@k, for k = 1, 3, 5, 10
+        let ndcg = evaluate_ndcg(query, search_results, 10, cx);
+        println!("NDCG: {:?}", ndcg);
+
         // Evaluate map@k, for k = 1, 3, 5, 10
         // Evaluate span count
         // Evaluate token count
@@ -259,6 +312,7 @@ fn main() {
 
         let node_runtime = RealNodeRuntime::new(http.clone());
         languages::init(languages.clone(), node_runtime.clone());
+        language::init(cx);
 
         project::Project::init(&client, cx);
         semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);
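
With language::init(cx) wiring up the language crate before Project::init, the harness should be runnable in the usual way for a crate example, presumably cargo run -p semantic_index --example eval, printing the per-query search time followed by the new NDCG vector.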