1use anyhow::{anyhow, Result};
2use client::{self, UserStore};
3use collections::HashMap;
4use git2::{Object, Oid, Repository};
5use gpui::{AppContext, AssetSource, AsyncAppContext, ModelHandle, Task};
6use language::LanguageRegistry;
7use node_runtime::RealNodeRuntime;
8use project::{Project, RealFs};
9use rust_embed::RustEmbed;
10use semantic_index::embedding::OpenAIEmbeddings;
11use semantic_index::semantic_index_settings::SemanticIndexSettings;
12use semantic_index::{SearchResult, SemanticIndex};
13use serde::Deserialize;
14use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore};
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18use std::{cmp, env, fs};
19use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
20use util::http::{self};
21use util::paths::{self, EMBEDDINGS_DIR};
22use zed::languages;
23
/// Static assets embedded into the binary at compile time via `rust_embed`,
/// pulled from the workspace `assets` directory (fonts, icons, themes,
/// sounds, and top-level markdown files); macOS `.DS_Store` junk is excluded.
#[derive(RustEmbed)]
#[folder = "../../assets"]
#[include = "fonts/**/*"]
#[include = "icons/**/*"]
#[include = "themes/**/*"]
#[include = "sounds/**/*"]
#[include = "*.md"]
#[exclude = "*.DS_Store"]
pub struct Assets;
33
34impl AssetSource for Assets {
35 fn load(&self, path: &str) -> Result<std::borrow::Cow<[u8]>> {
36 Self::get(path)
37 .map(|f| f.data)
38 .ok_or_else(|| anyhow!("could not find asset at path \"{}\"", path))
39 }
40
41 fn list(&self, path: &str) -> Vec<std::borrow::Cow<'static, str>> {
42 Self::iter().filter(|p| p.starts_with(path)).collect()
43 }
44}
45
/// A single search query plus the source locations expected in its results.
#[derive(Deserialize, Clone)]
struct EvaluationQuery {
    // Search query fed to `SemanticIndex::search_project`.
    query: String,
    // Expected matches, encoded as "path:row" strings (see `match_pairs`).
    matches: Vec<String>,
}
51
52impl EvaluationQuery {
53 fn match_pairs(&self) -> Vec<(PathBuf, u32)> {
54 let mut pairs = Vec::new();
55 for match_identifier in self.matches.iter() {
56 let mut match_parts = match_identifier.split(":");
57
58 if let Some(file_path) = match_parts.next() {
59 if let Some(row_number) = match_parts.next() {
60 pairs.push((PathBuf::from(file_path), row_number.parse::<u32>().unwrap()));
61 }
62 }
63 }
64 pairs
65 }
66}
67
/// One evaluation fixture: a repository pinned to a commit, plus the search
/// queries (and their expected matches) to run against it.
#[derive(Deserialize, Clone)]
struct RepoEval {
    // Clone URL of the repository (passed to `git2`'s `RepoBuilder::clone`).
    repo: String,
    // Commit hash checked out (detached HEAD) before indexing.
    commit: String,
    // Query/expected-match assertions evaluated by `evaluate_repo`.
    assertions: Vec<EvaluationQuery>,
}
74
/// Aggregate metrics for one repository evaluation.
///
/// NOTE(review): nothing in the visible code constructs this struct —
/// `evaluate_repo` only prints per-query metrics. Presumably intended to
/// collect results across queries (the "Evaluate span count / token count"
/// TODOs there); confirm before relying on it.
struct EvaluationResults {
    token_count: usize,
    span_count: usize,
    time_to_index: Duration,
    // One entry per query searched.
    time_to_search: Vec<Duration>,
    // Metric value keyed by cutoff k (e.g. NDCG@1, NDCG@3, …).
    ndcg: HashMap<usize, f32>,
    map: HashMap<usize, f32>,
}
83
// Directory (two levels above the current working directory — see
// `clone_repo`) into which evaluation repositories are cloned; any existing
// clone is wiped and re-cloned on each run.
const TMP_REPO_PATH: &str = "eval_repos";
85
86fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
87 let eval_folder = env::current_dir()?
88 .as_path()
89 .parent()
90 .unwrap()
91 .join("crates/semantic_index/eval");
92
93 let mut repo_evals: Vec<RepoEval> = Vec::new();
94 for entry in fs::read_dir(eval_folder)? {
95 let file_path = entry.unwrap().path();
96 if let Some(extension) = file_path.extension() {
97 if extension == "json" {
98 if let Ok(file) = fs::read_to_string(file_path) {
99 let repo_eval = serde_json::from_str(file.as_str());
100
101 match repo_eval {
102 Ok(repo_eval) => {
103 repo_evals.push(repo_eval);
104 }
105 Err(err) => {
106 println!("Err: {:?}", err);
107 }
108 }
109 }
110 }
111 }
112 }
113
114 Ok(repo_evals)
115}
116
117fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
118 let repo_name = Path::new(repo_eval.repo.as_str())
119 .file_name()
120 .unwrap()
121 .to_str()
122 .unwrap()
123 .to_owned()
124 .replace(".git", "");
125
126 let clone_path = fs::canonicalize(env::current_dir()?)?
127 .parent()
128 .ok_or(anyhow!("path canonicalization failed"))?
129 .parent()
130 .unwrap()
131 .join(TMP_REPO_PATH)
132 .join(&repo_name);
133
134 // Delete Clone Path if already exists
135 let _ = fs::remove_dir_all(&clone_path);
136
137 // Clone in Repo
138 git2::build::RepoBuilder::new()
139 // .branch(repo_eval.sha.as_str())
140 .clone(repo_eval.repo.as_str(), clone_path.as_path())?;
141
142 let repo: Repository = Repository::open(clone_path.clone())?;
143 let obj: Object = repo
144 .find_commit(Oid::from_str(repo_eval.commit.as_str())?)?
145 .into_object();
146 repo.checkout_tree(&obj, None)?;
147 repo.set_head_detached(obj.id())?;
148
149 Ok(clone_path)
150}
151
/// Discounted cumulative gain: each hit at rank `i` (0-based) is weighted by
/// `1 / log2(i + 2)`, so relevant results near the top contribute more.
fn dcg(hits: Vec<usize>) -> f32 {
    hits.iter()
        .enumerate()
        .map(|(rank, &hit)| hit as f32 / (2.0 + rank as f32).log2())
        .sum()
}
160
161fn get_hits(
162 eval_query: EvaluationQuery,
163 search_results: Vec<SearchResult>,
164 k: usize,
165 cx: &AsyncAppContext,
166) -> (Vec<usize>, Vec<usize>) {
167 let ideal = vec![1; cmp::min(eval_query.matches.len(), k)];
168
169 let mut hits = Vec::new();
170 for result in search_results {
171 let (path, start_row, end_row) = result.buffer.read_with(cx, |buffer, cx| {
172 let path = buffer.file().unwrap().path().to_path_buf();
173 let start_row = buffer.offset_to_point(result.range.start.offset).row;
174 let end_row = buffer.offset_to_point(result.range.end.offset).row;
175 (path, start_row, end_row)
176 });
177
178 let match_pairs = eval_query.match_pairs();
179 let mut found = 0;
180 for (match_path, match_row) in match_pairs {
181 if match_path == path {
182 if match_row >= start_row && match_row <= end_row {
183 found = 1;
184 break;
185 }
186 }
187 }
188
189 hits.push(found);
190 }
191
192 // For now, we are calculating ideal_hits a bit different, as technically
193 // with overlapping ranges, one match can result in more than result.
194 let mut ideal_hits = hits.clone();
195 ideal_hits.retain(|x| x == &1);
196
197 let ideal = if ideal.len() > ideal_hits.len() {
198 ideal
199 } else {
200 ideal_hits
201 };
202
203 // Fill ideal to 10 length
204 let mut filled_ideal = [0; 10];
205 for (idx, i) in ideal.to_vec().into_iter().enumerate() {
206 filled_ideal[idx] = i;
207 }
208
209 (filled_ideal.to_vec(), hits)
210}
211
212fn evaluate_ndcg(hits: Vec<usize>, ideal: Vec<usize>) -> Vec<f32> {
213 // NDCG or Normalized Discounted Cumulative Gain, is determined by comparing the relevance of
214 // items returned by the search engine relative to the hypothetical ideal.
215 // Relevance is represented as a series of booleans, in which each search result returned
216 // is identified as being inside the test set of matches (1) or not (0).
217
218 // For example, if result 1, 3 and 5 match the 3 relevant results provided
219 // actual dcg is calculated against a vector of [1, 0, 1, 0, 1]
220 // whereas ideal dcg is calculated against a vector of [1, 1, 1, 0, 0]
221 // as this ideal vector assumes the 3 relevant results provided were returned first
222 // normalized dcg is then calculated as actual dcg / ideal dcg.
223
224 // NDCG ranges from 0 to 1, which higher values indicating better performance
225 // Commonly NDCG is expressed as NDCG@k, in which k represents the metric calculated
226 // including only the top k values returned.
227 // The @k metrics can help you identify, at what point does the relevant results start to fall off.
228 // Ie. a NDCG@1 of 0.9 and a NDCG@3 of 0.5 may indicate that the first result returned in usually
229 // very high quality, whereas rank results quickly drop off after the first result.
230
231 let mut ndcg = Vec::new();
232 for idx in 1..(hits.len() + 1) {
233 let hits_at_k = hits[0..idx].to_vec();
234 let ideal_at_k = ideal[0..idx].to_vec();
235
236 let at_k = dcg(hits_at_k.clone()) / dcg(ideal_at_k.clone());
237
238 ndcg.push(at_k);
239 }
240
241 ndcg
242}
243
/// Compute MAP@k (mean average precision) for every prefix length k of `hits`.
///
/// `map_at_k[i]` accumulates precision-at-rank over the first `i + 1` results,
/// normalized by the total number of relevant results. If no result is
/// relevant at all, the whole vector is 0.0 to avoid dividing by zero.
fn evaluate_map(hits: Vec<usize>) -> Vec<f32> {
    let total_relevant = hits.iter().sum::<usize>() as f32;
    if total_relevant == 0.0 {
        return vec![0.0; hits.len()];
    }

    let mut relevant_so_far = 0.0;
    let mut precision_sum = 0.0;
    hits.into_iter()
        .enumerate()
        .map(|(rank, hit)| {
            relevant_so_far += hit as f32;
            precision_sum += relevant_so_far / (rank + 1) as f32;
            precision_sum / total_relevant
        })
        .collect()
}
262
/// Initialize `env_logger` so the `log::trace!` calls below can be enabled
/// via the `RUST_LOG` environment variable.
fn init_logger() {
    env_logger::init();
}
266
/// Index `project` with the semantic index, then run every evaluation query
/// against it, printing indexing/search timings and NDCG@k / MAP@k per query.
///
/// Results are only printed, not collected — see the `EvaluationResults`
/// struct above, which this presumably was meant to populate (TODO confirm).
async fn evaluate_repo(
    index: ModelHandle<SemanticIndex>,
    project: ModelHandle<Project>,
    query_matches: Vec<EvaluationQuery>,
    cx: &mut AsyncAppContext,
) -> Result<()> {
    // Index Project
    let index_t0 = Instant::now();
    index
        .update(cx, |index, cx| index.index_project(project.clone(), cx))
        .await?;
    let index_time = index_t0.elapsed();
    println!("Time to Index: {:?}", index_time.as_millis());

    for query in query_matches {
        // Query each match in order
        let search_t0 = Instant::now();
        let search_results = index
            .update(cx, |index, cx| {
                // Top 10 results, no include/exclude path filters.
                index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx)
            })
            .await?;
        let search_time = search_t0.elapsed();
        println!("Time to Search: {:?}", search_time.as_millis());

        // Get Hits/Ideal
        let k = 10;
        let (ideal, hits) = self::get_hits(query, search_results, k, cx);

        // Evaluate ndcg@k, for k = 1, 3, 5, 10
        let ndcg = evaluate_ndcg(hits.clone(), ideal);
        println!("NDCG: {:?}", ndcg);

        // Evaluate map@k, for k = 1, 3, 5, 10
        let map = evaluate_map(hits);
        println!("MAP: {:?}", map);

        // Evaluate span count
        // Evaluate token count
    }

    anyhow::Ok(())
}
310
/// Entry point: boot a headless gpui app with the same settings/language/
/// semantic-index initialization as Zed, then clone and evaluate each repo
/// fixture returned by `parse_eval`.
fn main() {
    // Launch new repo as a new Zed workspace/project
    let app = gpui::App::new(Assets).unwrap();
    let fs = Arc::new(RealFs);
    let http = http::client();
    let user_settings_file_rx =
        watch_config_file(app.background(), fs.clone(), paths::SETTINGS.clone());
    // NOTE(review): a second, separate HTTP client is created here and used
    // for the user store and OpenAI embeddings; `http` above serves the
    // collab client / node runtime. Presumably intentional — confirm.
    let http_client = http::client();
    init_logger();

    app.run(move |cx| {
        cx.set_global(*RELEASE_CHANNEL);

        let client = client::Client::new(http.clone(), cx);
        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client.clone(), cx));

        // Initialize Settings
        let mut store = SettingsStore::default();
        store
            .set_default_settings(default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        handle_settings_file_changes(user_settings_file_rx, cx);

        // Initialize Languages
        let login_shell_env_loaded = Task::ready(());
        let mut languages = LanguageRegistry::new(login_shell_env_loaded);
        languages.set_executor(cx.background().clone());
        let languages = Arc::new(languages);

        let node_runtime = RealNodeRuntime::new(http.clone());
        languages::init(languages.clone(), node_runtime.clone());
        language::init(cx);

        project::Project::init(&client, cx);
        semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);

        settings::register::<SemanticIndexSettings>(cx);

        // Embeddings database lives under the per-release-channel dir.
        let db_file_path = EMBEDDINGS_DIR
            .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
            .join("embeddings_db");

        // Shadowing clones so the async move block below owns its handles.
        let languages = languages.clone();
        let fs = fs.clone();
        cx.spawn(|mut cx| async move {
            let semantic_index = SemanticIndex::new(
                fs.clone(),
                db_file_path,
                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
                languages.clone(),
                cx.clone(),
            )
            .await?;

            if let Ok(repo_evals) = parse_eval() {
                for repo in repo_evals {
                    let cloned = clone_repo(repo.clone());
                    match cloned {
                        Ok(clone_path) => {
                            log::trace!(
                                "Cloned {:?} @ {:?} into {:?}",
                                repo.repo,
                                repo.commit,
                                &clone_path
                            );

                            // Create Project
                            let project = cx.update(|cx| {
                                Project::local(
                                    client.clone(),
                                    user_store.clone(),
                                    languages.clone(),
                                    fs.clone(),
                                    cx,
                                )
                            });

                            // Register Worktree
                            let _ = project
                                .update(&mut cx, |project, cx| {
                                    project.find_or_create_local_worktree(clone_path, true, cx)
                                })
                                .await;

                            evaluate_repo(
                                semantic_index.clone(),
                                project,
                                repo.assertions,
                                &mut cx,
                            )
                            .await?;
                        }
                        Err(err) => {
                            // Clone failures skip the repo rather than abort.
                            log::trace!("Error cloning: {:?}", err);
                        }
                    }
                }
            }

            anyhow::Ok(())
        })
        .detach();
    });
}