// eval.rs

  1use anyhow::{anyhow, Result};
  2use client::{self, UserStore};
  3use gpui::{AsyncAppContext, ModelHandle, Task};
  4use language::LanguageRegistry;
  5use node_runtime::RealNodeRuntime;
  6use project::{Project, RealFs};
  7use semantic_index::embedding::OpenAIEmbeddings;
  8use semantic_index::semantic_index_settings::SemanticIndexSettings;
  9use semantic_index::{SearchResult, SemanticIndex};
 10use serde::{Deserialize, Serialize};
 11use settings::{default_settings, SettingsStore};
 12use std::path::{Path, PathBuf};
 13use std::process::Command;
 14use std::sync::Arc;
 15use std::time::{Duration, Instant};
 16use std::{cmp, env, fs};
 17use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 18use util::http::{self};
 19use util::paths::EMBEDDINGS_DIR;
 20use zed::languages;
 21
/// A single evaluation case: a natural-language search query plus the
/// set of expected matches, each encoded as a "path:row" string.
#[derive(Deserialize, Clone, Serialize)]
struct EvaluationQuery {
    // Free-text query sent to the semantic index.
    query: String,
    // Expected results, each formatted as "<file_path>:<row_number>".
    matches: Vec<String>,
}
 27
 28impl EvaluationQuery {
 29    fn match_pairs(&self) -> Vec<(PathBuf, u32)> {
 30        let mut pairs = Vec::new();
 31        for match_identifier in self.matches.iter() {
 32            let mut match_parts = match_identifier.split(":");
 33
 34            if let Some(file_path) = match_parts.next() {
 35                if let Some(row_number) = match_parts.next() {
 36                    pairs.push((PathBuf::from(file_path), row_number.parse::<u32>().unwrap()));
 37                }
 38            }
 39        }
 40        pairs
 41    }
 42}
 43
/// One evaluation target: a git repository pinned to a commit, plus the
/// list of query/expected-match assertions to run against it.
#[derive(Deserialize, Clone)]
struct RepoEval {
    // Clone URL (or path) of the repository; its file name becomes the repo name.
    repo: String,
    // Commit to check out so results are reproducible.
    commit: String,
    // Queries and their expected matches for this repo.
    assertions: Vec<EvaluationQuery>,
}
 50
// Directory (created two levels above the current working directory; see
// `clone_repo`) into which evaluation repositories are cloned. It is wiped
// and recreated on every run.
const TMP_REPO_PATH: &str = "eval_repos";
 52
 53fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
 54    let eval_folder = env::current_dir()?
 55        .as_path()
 56        .parent()
 57        .unwrap()
 58        .join("crates/semantic_index/eval");
 59
 60    let mut repo_evals: Vec<RepoEval> = Vec::new();
 61    for entry in fs::read_dir(eval_folder)? {
 62        let file_path = entry.unwrap().path();
 63        if let Some(extension) = file_path.extension() {
 64            if extension == "json" {
 65                if let Ok(file) = fs::read_to_string(file_path) {
 66                    let repo_eval = serde_json::from_str(file.as_str());
 67
 68                    match repo_eval {
 69                        Ok(repo_eval) => {
 70                            repo_evals.push(repo_eval);
 71                        }
 72                        Err(err) => {
 73                            println!("Err: {:?}", err);
 74                        }
 75                    }
 76                }
 77            }
 78        }
 79    }
 80
 81    Ok(repo_evals)
 82}
 83
 84fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<(String, PathBuf)> {
 85    let repo_name = Path::new(repo_eval.repo.as_str())
 86        .file_name()
 87        .unwrap()
 88        .to_str()
 89        .unwrap()
 90        .to_owned()
 91        .replace(".git", "");
 92
 93    let clone_path = fs::canonicalize(env::current_dir()?)?
 94        .parent()
 95        .ok_or(anyhow!("path canonicalization failed"))?
 96        .parent()
 97        .unwrap()
 98        .join(TMP_REPO_PATH);
 99
100    // Delete Clone Path if already exists
101    let _ = fs::remove_dir_all(&clone_path);
102    let _ = fs::create_dir(&clone_path);
103
104    let _ = Command::new("git")
105        .args(["clone", repo_eval.repo.as_str()])
106        .current_dir(clone_path.clone())
107        .output()?;
108    // Update clone path to be new directory housing the repo.
109    let clone_path = clone_path.join(repo_name.clone());
110    let _ = Command::new("git")
111        .args(["checkout", repo_eval.commit.as_str()])
112        .current_dir(clone_path.clone())
113        .output()?;
114
115    Ok((repo_name, clone_path))
116}
117
/// Discounted cumulative gain: each hit's relevance (0 or 1 here) is
/// discounted by log2 of its 2-based rank, so earlier hits count more.
fn dcg(hits: Vec<usize>) -> f32 {
    hits.iter()
        .enumerate()
        .map(|(rank, hit)| *hit as f32 / (rank as f32 + 2.0).log2())
        .sum()
}
126
127fn get_hits(
128    eval_query: EvaluationQuery,
129    search_results: Vec<SearchResult>,
130    k: usize,
131    cx: &AsyncAppContext,
132) -> (Vec<usize>, Vec<usize>) {
133    let ideal = vec![1; cmp::min(eval_query.matches.len(), k)];
134
135    let mut hits = Vec::new();
136    for result in search_results {
137        let (path, start_row, end_row) = result.buffer.read_with(cx, |buffer, _cx| {
138            let path = buffer.file().unwrap().path().to_path_buf();
139            let start_row = buffer.offset_to_point(result.range.start.offset).row;
140            let end_row = buffer.offset_to_point(result.range.end.offset).row;
141            (path, start_row, end_row)
142        });
143
144        let match_pairs = eval_query.match_pairs();
145        let mut found = 0;
146        for (match_path, match_row) in match_pairs {
147            if match_path == path {
148                if match_row >= start_row && match_row <= end_row {
149                    found = 1;
150                    break;
151                }
152            }
153        }
154
155        hits.push(found);
156    }
157
158    // For now, we are calculating ideal_hits a bit different, as technically
159    // with overlapping ranges, one match can result in more than result.
160    let mut ideal_hits = hits.clone();
161    ideal_hits.retain(|x| x == &1);
162
163    let ideal = if ideal.len() > ideal_hits.len() {
164        ideal
165    } else {
166        ideal_hits
167    };
168
169    // Fill ideal to 10 length
170    let mut filled_ideal = [0; 10];
171    for (idx, i) in ideal.to_vec().into_iter().enumerate() {
172        filled_ideal[idx] = i;
173    }
174
175    (filled_ideal.to_vec(), hits)
176}
177
178fn evaluate_ndcg(hits: Vec<usize>, ideal: Vec<usize>) -> Vec<f32> {
179    // NDCG or Normalized Discounted Cumulative Gain, is determined by comparing the relevance of
180    // items returned by the search engine relative to the hypothetical ideal.
181    // Relevance is represented as a series of booleans, in which each search result returned
182    // is identified as being inside the test set of matches (1) or not (0).
183
184    // For example, if result 1, 3 and 5 match the 3 relevant results provided
185    // actual dcg is calculated against a vector of [1, 0, 1, 0, 1]
186    // whereas ideal dcg is calculated against a vector of [1, 1, 1, 0, 0]
187    // as this ideal vector assumes the 3 relevant results provided were returned first
188    // normalized dcg is then calculated as actual dcg / ideal dcg.
189
190    // NDCG ranges from 0 to 1, which higher values indicating better performance
191    // Commonly NDCG is expressed as NDCG@k, in which k represents the metric calculated
192    // including only the top k values returned.
193    // The @k metrics can help you identify, at what point does the relevant results start to fall off.
194    // Ie. a NDCG@1 of 0.9 and a NDCG@3 of 0.5 may indicate that the first result returned in usually
195    // very high quality, whereas rank results quickly drop off after the first result.
196
197    let mut ndcg = Vec::new();
198    for idx in 1..(hits.len() + 1) {
199        let hits_at_k = hits[0..idx].to_vec();
200        let ideal_at_k = ideal[0..idx].to_vec();
201
202        let at_k = dcg(hits_at_k.clone()) / dcg(ideal_at_k.clone());
203
204        ndcg.push(at_k);
205    }
206
207    ndcg
208}
209
/// Mean average precision at every cutoff k.
///
/// `map[k-1]` is the sum of precision values at each hit position within
/// the first k results, normalized by the total number of hits returned.
/// If nothing relevant was returned at all, every entry is 0.0.
fn evaluate_map(hits: Vec<usize>) -> Vec<f32> {
    let total_hits = hits.iter().sum::<usize>() as f32;
    if total_hits == 0.0 {
        // Nothing relevant anywhere in the ranking: MAP is zero at every k.
        return vec![0.0; hits.len()];
    }

    let mut hits_so_far = 0.0;
    let mut precision_sum = 0.0;
    hits.iter()
        .enumerate()
        .map(|(rank, &hit)| {
            hits_so_far += hit as f32;
            if hit == 1 {
                // Precision at this hit's rank, accumulated across hits.
                precision_sum += hits_so_far / (rank as f32 + 1.0);
            }
            precision_sum / total_hits
        })
        .collect()
}
230
/// Mean reciprocal rank: `1 / rank` of the first hit (1-based), or 0.0
/// when the ranking contains no hit at all.
fn evaluate_mrr(hits: Vec<usize>) -> f32 {
    match hits.iter().position(|&h| h == 1) {
        Some(rank) => 1.0 / (rank + 1) as f32,
        None => 0.0,
    }
}
240
/// Initializes `env_logger`; log level is controlled via the `RUST_LOG`
/// environment variable.
fn init_logger() {
    env_logger::init();
}
244
/// Retrieval metrics for a single evaluation query against an indexed repo.
#[derive(Serialize)]
struct QueryMetrics {
    // The query and its expected matches.
    query: EvaluationQuery,
    // Wall-clock time the search took.
    millis_to_search: Duration,
    // Normalized discounted cumulative gain at each cutoff k.
    ndcg: Vec<f32>,
    // Mean average precision at each cutoff k.
    map: Vec<f32>,
    // Reciprocal rank of the first hit (0.0 if none).
    mrr: f32,
    // 1/0 relevance flag per returned result, in rank order.
    hits: Vec<usize>,
    // Precision at each cutoff k.
    precision: Vec<f32>,
    // Recall at each cutoff k.
    recall: Vec<f32>,
}
256
/// Per-repo averages of the per-query metrics (see `QueryMetrics`),
/// produced by `RepoEvaluationMetrics::summarize`.
#[derive(Serialize)]
struct SummaryMetrics {
    // Mean search latency in milliseconds across queries.
    millis_to_search: f32,
    // Mean NDCG at each cutoff k.
    ndcg: Vec<f32>,
    // Mean MAP at each cutoff k.
    map: Vec<f32>,
    // Mean reciprocal rank.
    mrr: f32,
    // Mean precision at each cutoff k.
    precision: Vec<f32>,
    // Mean recall at each cutoff k.
    recall: Vec<f32>,
}
266
/// All metrics collected while evaluating one repository.
#[derive(Serialize)]
struct RepoEvaluationMetrics {
    // Wall-clock time spent indexing the whole project.
    millis_to_index: Duration,
    // Per-query metrics, in evaluation order.
    query_metrics: Vec<QueryMetrics>,
    // Averages across all queries; populated by `summarize`.
    repo_metrics: Option<SummaryMetrics>,
}
273
274impl RepoEvaluationMetrics {
275    fn new(millis_to_index: Duration) -> Self {
276        RepoEvaluationMetrics {
277            millis_to_index,
278            query_metrics: Vec::new(),
279            repo_metrics: None,
280        }
281    }
282
283    fn save(&self, repo_name: String) -> Result<()> {
284        let results_string = serde_json::to_string(&self)?;
285        fs::write(format!("./{}_evaluation.json", repo_name), results_string)
286            .expect("Unable to write file");
287        Ok(())
288    }
289
290    fn summarize(&mut self) {
291        let l = self.query_metrics.len() as f32;
292        let millis_to_search: f32 = self
293            .query_metrics
294            .iter()
295            .map(|metrics| metrics.millis_to_search.as_millis())
296            .sum::<u128>() as f32
297            / l;
298
299        let mut ndcg_sum = vec![0.0; 10];
300        let mut map_sum = vec![0.0; 10];
301        let mut precision_sum = vec![0.0; 10];
302        let mut recall_sum = vec![0.0; 10];
303        let mut mmr_sum = 0.0;
304
305        for query_metric in self.query_metrics.iter() {
306            for (ndcg, query_ndcg) in ndcg_sum.iter_mut().zip(query_metric.ndcg.clone()) {
307                *ndcg += query_ndcg;
308            }
309
310            for (mapp, query_map) in map_sum.iter_mut().zip(query_metric.map.clone()) {
311                *mapp += query_map;
312            }
313
314            for (pre, query_pre) in precision_sum.iter_mut().zip(query_metric.precision.clone()) {
315                *pre += query_pre;
316            }
317
318            for (rec, query_rec) in recall_sum.iter_mut().zip(query_metric.recall.clone()) {
319                *rec += query_rec;
320            }
321
322            mmr_sum += query_metric.mrr;
323        }
324
325        let ndcg = ndcg_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
326        let map = map_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
327        let precision = precision_sum
328            .iter()
329            .map(|val| val / l)
330            .collect::<Vec<f32>>();
331        let recall = recall_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
332        let mrr = mmr_sum / l;
333
334        self.repo_metrics = Some(SummaryMetrics {
335            millis_to_search,
336            ndcg,
337            map,
338            mrr,
339            precision,
340            recall,
341        })
342    }
343}
344
/// Precision at every cutoff k: the fraction of the first k results that
/// are hits.
fn evaluate_precision(hits: Vec<usize>) -> Vec<f32> {
    let mut hits_so_far: f32 = 0.0;
    hits.into_iter()
        .enumerate()
        .map(|(rank, hit)| {
            hits_so_far += hit as f32;
            hits_so_far / (rank as f32 + 1.0)
        })
        .collect()
}
355
/// Recall at every cutoff k: the fraction of all relevant results (as
/// counted from `ideal`) found within the first k results.
fn evaluate_recall(hits: Vec<usize>, ideal: Vec<usize>) -> Vec<f32> {
    let total_relevant = ideal.iter().sum::<usize>() as f32;
    // With zero relevant results recall is undefined; report 0.0 at every
    // cutoff instead of dividing by zero (NaN/inf).
    if total_relevant == 0.0 {
        return vec![0.0; hits.len()];
    }

    let mut recall = Vec::new();
    let mut rolling_hit: f32 = 0.0;
    for hit in hits {
        rolling_hit += hit as f32;
        recall.push(rolling_hit / total_relevant);
    }

    recall
}
367
/// Indexes `project` with the semantic index, runs every evaluation query
/// against it, computes retrieval metrics for each (NDCG/MAP/MRR/
/// precision/recall @k), and writes the per-repo report to disk.
async fn evaluate_repo(
    repo_name: String,
    index: ModelHandle<SemanticIndex>,
    project: ModelHandle<Project>,
    query_matches: Vec<EvaluationQuery>,
    cx: &mut AsyncAppContext,
) -> Result<RepoEvaluationMetrics> {
    // Index Project, timing how long a full index takes.
    let index_t0 = Instant::now();
    index
        .update(cx, |index, cx| index.index_project(project.clone(), cx))
        .await?;
    let mut repo_metrics = RepoEvaluationMetrics::new(index_t0.elapsed());

    for query in query_matches {
        // Query each match in order, timing each search independently.
        let search_t0 = Instant::now();
        let search_results = index
            .update(cx, |index, cx| {
                index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx)
            })
            .await?;
        let millis_to_search = search_t0.elapsed();

        // Get Hits/Ideal: 1/0 relevance flags for the 10 returned results.
        let k = 10;
        let (ideal, hits) = self::get_hits(query.clone(), search_results, k, cx);

        // Evaluate ndcg@k, for k = 1, 3, 5, 10
        let ndcg = evaluate_ndcg(hits.clone(), ideal.clone());

        // Evaluate map@k, for k = 1, 3, 5, 10
        let map = evaluate_map(hits.clone());

        // Evaluate mrr
        let mrr = evaluate_mrr(hits.clone());

        // Evaluate precision
        let precision = evaluate_precision(hits.clone());

        // Evaluate Recall
        let recall = evaluate_recall(hits.clone(), ideal);

        let query_metrics = QueryMetrics {
            query,
            millis_to_search,
            ndcg,
            map,
            mrr,
            hits,
            precision,
            recall,
        };

        repo_metrics.query_metrics.push(query_metrics);
    }

    repo_metrics.summarize();
    // NOTE(review): a failed save is silently discarded — only the written
    // report is lost, the metrics are still returned to the caller.
    let _ = repo_metrics.save(repo_name);

    anyhow::Ok(repo_metrics)
}
430
/// Entry point: boots a headless gpui app, wires up the client, settings,
/// languages, and semantic index, then clones and evaluates each repo
/// described by the JSON files found by `parse_eval`.
fn main() {
    // Launch new repo as a new Zed workspace/project
    let app = gpui::App::new(()).unwrap();
    let fs = Arc::new(RealFs);
    // NOTE(review): two separate HTTP clients are created; `http` feeds the
    // Zed client/node runtime and `http_client` the user store/embeddings —
    // confirm both are actually required.
    let http = http::client();
    let http_client = http::client();
    init_logger();

    app.run(move |cx| {
        cx.set_global(*RELEASE_CHANNEL);

        let client = client::Client::new(http.clone(), cx);
        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client.clone(), cx));

        // Initialize Settings
        let mut store = SettingsStore::default();
        store
            .set_default_settings(default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);

        // Initialize Languages; no login shell env is needed for the eval.
        let login_shell_env_loaded = Task::ready(());
        let mut languages = LanguageRegistry::new(login_shell_env_loaded);
        languages.set_executor(cx.background().clone());
        let languages = Arc::new(languages);

        let node_runtime = RealNodeRuntime::new(http.clone());
        languages::init(languages.clone(), node_runtime.clone());
        language::init(cx);

        project::Project::init(&client, cx);
        semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);

        settings::register::<SemanticIndexSettings>(cx);

        // Embeddings database lives under the per-release-channel dir.
        let db_file_path = EMBEDDINGS_DIR
            .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
            .join("embeddings_db");

        let languages = languages.clone();
        let fs = fs.clone();
        cx.spawn(|mut cx| async move {
            let semantic_index = SemanticIndex::new(
                fs.clone(),
                db_file_path,
                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
                languages.clone(),
                cx.clone(),
            )
            .await?;

            if let Ok(repo_evals) = parse_eval() {
                for repo in repo_evals {
                    let cloned = clone_repo(repo.clone());
                    match cloned {
                        Ok((repo_name, clone_path)) => {
                            println!(
                                "Cloned {:?} @ {:?} into {:?}",
                                repo.repo, repo.commit, &clone_path
                            );

                            // Create Project for the freshly cloned repo.
                            let project = cx.update(|cx| {
                                Project::local(
                                    client.clone(),
                                    user_store.clone(),
                                    languages.clone(),
                                    fs.clone(),
                                    cx,
                                )
                            });

                            // Register Worktree so the index can see the files.
                            let _ = project
                                .update(&mut cx, |project, cx| {
                                    project.find_or_create_local_worktree(clone_path, true, cx)
                                })
                                .await;

                            let _ = evaluate_repo(
                                repo_name,
                                semantic_index.clone(),
                                project,
                                repo.assertions,
                                &mut cx,
                            )
                            .await?;
                        }
                        Err(err) => {
                            // A failed clone skips this repo but continues the run.
                            println!("Error cloning: {:?}", err);
                        }
                    }
                }
            }

            anyhow::Ok(())
        })
        .detach();
    });
}