1use ai::embedding::OpenAIEmbeddings;
2use anyhow::{anyhow, Result};
3use client::{self, UserStore};
4use gpui::{AsyncAppContext, ModelHandle, Task};
5use language::LanguageRegistry;
6use node_runtime::RealNodeRuntime;
7use project::{Project, RealFs};
8use semantic_index::semantic_index_settings::SemanticIndexSettings;
9use semantic_index::{SearchResult, SemanticIndex};
10use serde::{Deserialize, Serialize};
11use settings::{default_settings, SettingsStore};
12use std::path::{Path, PathBuf};
13use std::process::Command;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{cmp, env, fs};
17use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
18use util::http::{self};
19use util::paths::EMBEDDINGS_DIR;
20use zed::languages;
21
/// A single evaluation case: a search query plus the set of expected
/// matches, each encoded as a `"file_path:row_number"` string (parsed by
/// `match_pairs`).
#[derive(Deserialize, Clone, Serialize)]
struct EvaluationQuery {
    query: String,
    // Each entry is "file_path:row_number".
    matches: Vec<String>,
}
27
28impl EvaluationQuery {
29 fn match_pairs(&self) -> Vec<(PathBuf, u32)> {
30 let mut pairs = Vec::new();
31 for match_identifier in self.matches.iter() {
32 let mut match_parts = match_identifier.split(":");
33
34 if let Some(file_path) = match_parts.next() {
35 if let Some(row_number) = match_parts.next() {
36 pairs.push((PathBuf::from(file_path), row_number.parse::<u32>().unwrap()));
37 }
38 }
39 }
40 pairs
41 }
42}
43
/// One eval JSON file: a git repository pinned to a specific commit, plus
/// the evaluation queries to run against it.
#[derive(Deserialize, Clone)]
struct RepoEval {
    // Git clone URL; ".git" is removed to derive the repo's directory name.
    repo: String,
    // Commit checked out after cloning, so results are reproducible.
    commit: String,
    assertions: Vec<EvaluationQuery>,
}
50
// Directory name (created two levels above the current working dir — see
// `clone_repo`) into which evaluation repositories are cloned.
const TMP_REPO_PATH: &str = "eval_repos";
52
53fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
54 let eval_folder = env::current_dir()?
55 .as_path()
56 .parent()
57 .unwrap()
58 .join("crates/semantic_index/eval");
59
60 let mut repo_evals: Vec<RepoEval> = Vec::new();
61 for entry in fs::read_dir(eval_folder)? {
62 let file_path = entry.unwrap().path();
63 if let Some(extension) = file_path.extension() {
64 if extension == "json" {
65 if let Ok(file) = fs::read_to_string(file_path) {
66 let repo_eval = serde_json::from_str(file.as_str());
67
68 match repo_eval {
69 Ok(repo_eval) => {
70 repo_evals.push(repo_eval);
71 }
72 Err(err) => {
73 println!("Err: {:?}", err);
74 }
75 }
76 }
77 }
78 }
79 }
80
81 Ok(repo_evals)
82}
83
84fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<(String, PathBuf)> {
85 let repo_name = Path::new(repo_eval.repo.as_str())
86 .file_name()
87 .unwrap()
88 .to_str()
89 .unwrap()
90 .to_owned()
91 .replace(".git", "");
92
93 let clone_path = fs::canonicalize(env::current_dir()?)?
94 .parent()
95 .ok_or(anyhow!("path canonicalization failed"))?
96 .parent()
97 .unwrap()
98 .join(TMP_REPO_PATH);
99
100 // Delete Clone Path if already exists
101 let _ = fs::remove_dir_all(&clone_path);
102 let _ = fs::create_dir(&clone_path);
103
104 let _ = Command::new("git")
105 .args(["clone", repo_eval.repo.as_str()])
106 .current_dir(clone_path.clone())
107 .output()?;
108 // Update clone path to be new directory housing the repo.
109 let clone_path = clone_path.join(repo_name.clone());
110 let _ = Command::new("git")
111 .args(["checkout", repo_eval.commit.as_str()])
112 .current_dir(clone_path.clone())
113 .output()?;
114
115 Ok((repo_name, clone_path))
116}
117
/// Discounted cumulative gain: each 0/1 hit is discounted by log2 of
/// (rank + 2), so hits at earlier ranks contribute more to the score.
fn dcg(hits: Vec<usize>) -> f32 {
    hits.iter()
        .enumerate()
        .map(|(rank, hit)| *hit as f32 / (rank as f32 + 2.0).log2())
        .sum()
}
126
127fn get_hits(
128 eval_query: EvaluationQuery,
129 search_results: Vec<SearchResult>,
130 k: usize,
131 cx: &AsyncAppContext,
132) -> (Vec<usize>, Vec<usize>) {
133 let ideal = vec![1; cmp::min(eval_query.matches.len(), k)];
134
135 let mut hits = Vec::new();
136 for result in search_results {
137 let (path, start_row, end_row) = result.buffer.read_with(cx, |buffer, _cx| {
138 let path = buffer.file().unwrap().path().to_path_buf();
139 let start_row = buffer.offset_to_point(result.range.start.offset).row;
140 let end_row = buffer.offset_to_point(result.range.end.offset).row;
141 (path, start_row, end_row)
142 });
143
144 let match_pairs = eval_query.match_pairs();
145 let mut found = 0;
146 for (match_path, match_row) in match_pairs {
147 if match_path == path {
148 if match_row >= start_row && match_row <= end_row {
149 found = 1;
150 break;
151 }
152 }
153 }
154
155 hits.push(found);
156 }
157
158 // For now, we are calculating ideal_hits a bit different, as technically
159 // with overlapping ranges, one match can result in more than result.
160 let mut ideal_hits = hits.clone();
161 ideal_hits.retain(|x| x == &1);
162
163 let ideal = if ideal.len() > ideal_hits.len() {
164 ideal
165 } else {
166 ideal_hits
167 };
168
169 // Fill ideal to 10 length
170 let mut filled_ideal = [0; 10];
171 for (idx, i) in ideal.to_vec().into_iter().enumerate() {
172 filled_ideal[idx] = i;
173 }
174
175 (filled_ideal.to_vec(), hits)
176}
177
178fn evaluate_ndcg(hits: Vec<usize>, ideal: Vec<usize>) -> Vec<f32> {
179 // NDCG or Normalized Discounted Cumulative Gain, is determined by comparing the relevance of
180 // items returned by the search engine relative to the hypothetical ideal.
181 // Relevance is represented as a series of booleans, in which each search result returned
182 // is identified as being inside the test set of matches (1) or not (0).
183
184 // For example, if result 1, 3 and 5 match the 3 relevant results provided
185 // actual dcg is calculated against a vector of [1, 0, 1, 0, 1]
186 // whereas ideal dcg is calculated against a vector of [1, 1, 1, 0, 0]
187 // as this ideal vector assumes the 3 relevant results provided were returned first
188 // normalized dcg is then calculated as actual dcg / ideal dcg.
189
190 // NDCG ranges from 0 to 1, which higher values indicating better performance
191 // Commonly NDCG is expressed as NDCG@k, in which k represents the metric calculated
192 // including only the top k values returned.
193 // The @k metrics can help you identify, at what point does the relevant results start to fall off.
194 // Ie. a NDCG@1 of 0.9 and a NDCG@3 of 0.5 may indicate that the first result returned in usually
195 // very high quality, whereas rank results quickly drop off after the first result.
196
197 let mut ndcg = Vec::new();
198 for idx in 1..(hits.len() + 1) {
199 let hits_at_k = hits[0..idx].to_vec();
200 let ideal_at_k = ideal[0..idx].to_vec();
201
202 let at_k = dcg(hits_at_k.clone()) / dcg(ideal_at_k.clone());
203
204 ndcg.push(at_k);
205 }
206
207 ndcg
208}
209
/// Mean average precision at every cutoff 1..=len.
///
/// At each rank where a hit occurs, the precision-so-far is added to a
/// running sum, which is then normalized by the total number of hits.
/// Returns all zeros when no result hit at all.
fn evaluate_map(hits: Vec<usize>) -> Vec<f32> {
    let total_hits = hits.iter().sum::<usize>() as f32;
    if total_hits == 0.0 {
        return vec![0.0; hits.len()];
    }

    let mut hits_so_far = 0.0;
    let mut precision_sum = 0.0;
    hits.into_iter()
        .enumerate()
        .map(|(rank, hit)| {
            hits_so_far += hit as f32;
            if hit == 1 {
                precision_sum += hits_so_far / (rank + 1) as f32;
            }
            precision_sum / total_hits
        })
        .collect()
}
230
/// Mean reciprocal rank: 1 / (1-based rank of the first hit), or 0.0 when
/// nothing hit.
fn evaluate_mrr(hits: Vec<usize>) -> f32 {
    hits.into_iter()
        .position(|h| h == 1)
        .map(|idx| 1.0 / (idx + 1) as f32)
        .unwrap_or(0.0)
}
240
/// Initializes `env_logger` so `RUST_LOG` controls log output for the run.
fn init_logger() {
    env_logger::init();
}
244
/// Per-query evaluation results, serialized into the repo's report JSON.
#[derive(Serialize)]
struct QueryMetrics {
    query: EvaluationQuery,
    // Wall-clock search time (a full Duration despite the "millis" name).
    millis_to_search: Duration,
    // Metric value at each cutoff k = 1..=10.
    ndcg: Vec<f32>,
    map: Vec<f32>,
    mrr: f32,
    // 0/1 relevance flag per returned result, in rank order.
    hits: Vec<usize>,
    precision: Vec<f32>,
    recall: Vec<f32>,
}
256
/// Metrics averaged over every query evaluated for a repo.
#[derive(Serialize)]
struct SummaryMetrics {
    // Mean search time in milliseconds across queries.
    millis_to_search: f32,
    // Mean metric value at each cutoff k = 1..=10.
    ndcg: Vec<f32>,
    map: Vec<f32>,
    mrr: f32,
    precision: Vec<f32>,
    recall: Vec<f32>,
}
266
/// Full evaluation record for one repository: indexing time, per-query
/// metrics, and (after `summarize`) their averages.
#[derive(Serialize)]
struct RepoEvaluationMetrics {
    // Time spent indexing the project before any queries ran.
    millis_to_index: Duration,
    query_metrics: Vec<QueryMetrics>,
    // Populated by `summarize`; None until then.
    repo_metrics: Option<SummaryMetrics>,
}
273
274impl RepoEvaluationMetrics {
275 fn new(millis_to_index: Duration) -> Self {
276 RepoEvaluationMetrics {
277 millis_to_index,
278 query_metrics: Vec::new(),
279 repo_metrics: None,
280 }
281 }
282
283 fn save(&self, repo_name: String) -> Result<()> {
284 let results_string = serde_json::to_string(&self)?;
285 fs::write(format!("./{}_evaluation.json", repo_name), results_string)
286 .expect("Unable to write file");
287 Ok(())
288 }
289
290 fn summarize(&mut self) {
291 let l = self.query_metrics.len() as f32;
292 let millis_to_search: f32 = self
293 .query_metrics
294 .iter()
295 .map(|metrics| metrics.millis_to_search.as_millis())
296 .sum::<u128>() as f32
297 / l;
298
299 let mut ndcg_sum = vec![0.0; 10];
300 let mut map_sum = vec![0.0; 10];
301 let mut precision_sum = vec![0.0; 10];
302 let mut recall_sum = vec![0.0; 10];
303 let mut mmr_sum = 0.0;
304
305 for query_metric in self.query_metrics.iter() {
306 for (ndcg, query_ndcg) in ndcg_sum.iter_mut().zip(query_metric.ndcg.clone()) {
307 *ndcg += query_ndcg;
308 }
309
310 for (mapp, query_map) in map_sum.iter_mut().zip(query_metric.map.clone()) {
311 *mapp += query_map;
312 }
313
314 for (pre, query_pre) in precision_sum.iter_mut().zip(query_metric.precision.clone()) {
315 *pre += query_pre;
316 }
317
318 for (rec, query_rec) in recall_sum.iter_mut().zip(query_metric.recall.clone()) {
319 *rec += query_rec;
320 }
321
322 mmr_sum += query_metric.mrr;
323 }
324
325 let ndcg = ndcg_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
326 let map = map_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
327 let precision = precision_sum
328 .iter()
329 .map(|val| val / l)
330 .collect::<Vec<f32>>();
331 let recall = recall_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
332 let mrr = mmr_sum / l;
333
334 self.repo_metrics = Some(SummaryMetrics {
335 millis_to_search,
336 ndcg,
337 map,
338 mrr,
339 precision,
340 recall,
341 })
342 }
343}
344
/// Precision at every cutoff 1..=len: cumulative hits divided by rank.
fn evaluate_precision(hits: Vec<usize>) -> Vec<f32> {
    let mut cumulative: f32 = 0.0;
    hits.into_iter()
        .enumerate()
        .map(|(rank, hit)| {
            cumulative += hit as f32;
            cumulative / (rank as f32 + 1.0)
        })
        .collect()
}
355
/// Recall at every cutoff 1..=len: cumulative hits divided by the total
/// number of relevant items in `ideal`. (When `ideal` sums to zero the
/// entries are non-finite — same as the original behavior.)
fn evaluate_recall(hits: Vec<usize>, ideal: Vec<usize>) -> Vec<f32> {
    let total_relevant = ideal.iter().sum::<usize>() as f32;
    let mut cumulative: f32 = 0.0;
    hits.into_iter()
        .map(|hit| {
            cumulative += hit as f32;
            cumulative / total_relevant
        })
        .collect()
}
367
/// Indexes `project` with the semantic index, runs every `EvaluationQuery`
/// against it, computes ranking metrics per query, then summarizes and
/// writes the results to `./{repo_name}_evaluation.json`.
async fn evaluate_repo(
    repo_name: String,
    index: ModelHandle<SemanticIndex>,
    project: ModelHandle<Project>,
    query_matches: Vec<EvaluationQuery>,
    cx: &mut AsyncAppContext,
) -> Result<RepoEvaluationMetrics> {
    // Index Project
    let index_t0 = Instant::now();
    index
        .update(cx, |index, cx| index.index_project(project.clone(), cx))
        .await?;
    let mut repo_metrics = RepoEvaluationMetrics::new(index_t0.elapsed());

    for query in query_matches {
        // Query each match in order, timing only the search itself.
        let search_t0 = Instant::now();
        let search_results = index
            .update(cx, |index, cx| {
                // 10 = number of results requested; empty vecs are include /
                // exclude path filters (none applied here).
                index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx)
            })
            .await?;
        let millis_to_search = search_t0.elapsed();

        // Get Hits/Ideal relevance vectors for the metric functions.
        let k = 10;
        let (ideal, hits) = self::get_hits(query.clone(), search_results, k, cx);

        // Evaluate ndcg@k, for k = 1, 3, 5, 10
        let ndcg = evaluate_ndcg(hits.clone(), ideal.clone());

        // Evaluate map@k, for k = 1, 3, 5, 10
        let map = evaluate_map(hits.clone());

        // Evaluate mrr
        let mrr = evaluate_mrr(hits.clone());

        // Evaluate precision
        let precision = evaluate_precision(hits.clone());

        // Evaluate Recall
        let recall = evaluate_recall(hits.clone(), ideal);

        let query_metrics = QueryMetrics {
            query,
            millis_to_search,
            ndcg,
            map,
            mrr,
            hits,
            precision,
            recall,
        };

        repo_metrics.query_metrics.push(query_metrics);
    }

    repo_metrics.summarize();
    // NOTE(review): a failed save is silently ignored here — the metrics are
    // still returned to the caller, but nothing lands on disk.
    let _ = repo_metrics.save(repo_name);

    anyhow::Ok(repo_metrics)
}
430
/// Entry point: boots a headless gpui app, initializes languages / settings /
/// the semantic index, then clones and evaluates each repo described by the
/// eval JSON files.
fn main() {
    // Launch new repo as a new Zed workspace/project
    let app = gpui::App::new(()).unwrap();
    let fs = Arc::new(RealFs);
    // NOTE(review): two separate HTTP clients are created here — one for the
    // collab client, one for embeddings/UserStore. Presumably intentional,
    // but confirm a single shared client wouldn't suffice.
    let http = http::client();
    let http_client = http::client();
    init_logger();

    app.run(move |cx| {
        cx.set_global(*RELEASE_CHANNEL);

        let client = client::Client::new(http.clone(), cx);
        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client.clone(), cx));

        // Initialize Settings (defaults must be registered before any
        // settings consumers run).
        let mut store = SettingsStore::default();
        store
            .set_default_settings(default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);

        // Initialize Languages. No login shell env is needed for this
        // headless eval, so the load task is pre-resolved.
        let login_shell_env_loaded = Task::ready(());
        let mut languages = LanguageRegistry::new(login_shell_env_loaded);
        languages.set_executor(cx.background().clone());
        let languages = Arc::new(languages);

        let node_runtime = RealNodeRuntime::new(http.clone());
        languages::init(languages.clone(), node_runtime.clone(), cx);
        language::init(cx);

        project::Project::init(&client, cx);
        semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);

        settings::register::<SemanticIndexSettings>(cx);

        // Embeddings DB lives under the per-release-channel embeddings dir.
        let db_file_path = EMBEDDINGS_DIR
            .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
            .join("embeddings_db");

        let languages = languages.clone();
        let fs = fs.clone();
        cx.spawn(|mut cx| async move {
            let semantic_index = SemanticIndex::new(
                fs.clone(),
                db_file_path,
                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
                languages.clone(),
                cx.clone(),
            )
            .await?;

            if let Ok(repo_evals) = parse_eval() {
                for repo in repo_evals {
                    let cloned = clone_repo(repo.clone());
                    match cloned {
                        Ok((repo_name, clone_path)) => {
                            println!(
                                "Cloned {:?} @ {:?} into {:?}",
                                repo.repo, repo.commit, &clone_path
                            );

                            // Create Project
                            let project = cx.update(|cx| {
                                Project::local(
                                    client.clone(),
                                    user_store.clone(),
                                    languages.clone(),
                                    fs.clone(),
                                    cx,
                                )
                            });

                            // Register Worktree so the index can see the
                            // cloned files.
                            let _ = project
                                .update(&mut cx, |project, cx| {
                                    project.find_or_create_local_worktree(clone_path, true, cx)
                                })
                                .await;

                            let _ = evaluate_repo(
                                repo_name,
                                semantic_index.clone(),
                                project,
                                repo.assertions,
                                &mut cx,
                            )
                            .await?;
                        }
                        Err(err) => {
                            println!("Error cloning: {:?}", err);
                        }
                    }
                }
            }

            anyhow::Ok(())
        })
        .detach();
    });
}