use ::fs::{Fs, RealFs};
use anyhow::Result;
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::FakeNodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use semantic_index::{OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Command, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};

const CODESEARCH_NET_DIR: &str = "target/datasets/code-search-net";
const EVAL_REPOS_DIR: &str = "target/datasets/eval-repos";
const EVAL_DB_PATH: &str = "target/eval_db";
const SEARCH_RESULT_LIMIT: usize = 8;
const SKIP_EVAL_PATH: &str = ".skip_eval";

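/// Command-line interface for the semantic search evaluation binary:
/// `fetch` downloads the CodeSearchNet annotations and clones the evaluation
/// repositories, and `run` executes the evaluation, optionally restricted to a
/// single repository via `--repo`.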
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(clap::Subcommand)]
enum Commands {
    Fetch {},
    Run {
        #[arg(long)]
        repo: Option<String>,
    },
}

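/// A repository pinned to a specific SHA, together with the search queries
/// (and their human-annotated expected results) drawn from the CodeSearchNet
/// annotations.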
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQuery>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    file: String,
    lines: RangeInclusive<u32>,
}

#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    repo: String,
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
    actual_results: Vec<EvaluationSearchResult>,
    covered_file_count: usize,
    overlapped_result_count: usize,
    covered_result_count: usize,
    total_result_count: usize,
    covered_result_indices: Vec<usize>,
}

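/// Entry point. Parses the CLI and runs the selected subcommand on a headless
/// GPUI app. Typical invocations look roughly like the following (the package
/// name `evals` is an assumption; substitute whatever this crate is actually
/// called):
///
/// ```text
/// cargo run -p evals -- fetch
/// cargo run -p evals -- run --repo owner/name
/// ```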
fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::App::headless().run(move |cx| {
        let executor = cx.background_executor().clone();

        match cli.command {
            Commands::Fetch {} => {
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(&executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                cx.spawn(|mut cx| async move {
                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}

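/// Downloads the CodeSearchNet annotations and then clones every evaluation
/// repository referenced by them.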
async fn fetch_evaluation_resources(executor: &BackgroundExecutor) -> Result<()> {
    let http_client = http_client::HttpClientWithProxy::new(None, None);
    fetch_code_search_net_resources(&http_client).await?;
    fetch_eval_repos(executor, &http_client).await?;
    Ok(())
}

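/// Fetches the CodeSearchNet `annotationStore.csv` file and converts it into
/// `evaluations.json`, grouping the annotated queries by repository and SHA.
///
/// The parser below assumes each CSV row looks roughly like this (inferred
/// from how the line is split, not from the upstream dataset documentation):
///
/// ```text
/// Language,Query,GitHubUrl,Relevance
/// python,convert int to string,https://github.com/owner/repo/blob/<sha>/path/to/file.py#L10-L20,3
/// ```
///
/// Rows whose relevance score is "0" are skipped.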
async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
    eprintln!("Fetching CodeSearchNet evaluations...");

    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";

    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");

    // Fetch the annotations CSV, which contains the human-annotated search relevances
    let annotations_path = dataset_dir.join("annotations.csv");
    let annotations_csv_content = if annotations_path.exists() {
        fs::read_to_string(&annotations_path).expect("failed to read annotations")
    } else {
        let response = http_client
            .get(annotations_url, Default::default(), true)
            .await
            .expect("failed to fetch annotations csv");
        let mut body = String::new();
        response
            .into_body()
            .read_to_string(&mut body)
            .await
            .expect("failed to read annotations.csv response");
        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
        body
    };

    // Parse the annotations CSV. Skip over queries with zero relevance.
    let rows = annotations_csv_content.lines().filter_map(|line| {
        let mut values = line.split(',');
        let _language = values.next()?;
        let query = values.next()?;
        let github_url = values.next()?;
        let score = values.next()?;

        if score == "0" {
            return None;
        }

        let url_path = github_url.strip_prefix("https://github.com/")?;
        let (url_path, hash) = url_path.split_once('#')?;
        let (repo_name, url_path) = url_path.split_once("/blob/")?;
        let (sha, file_path) = url_path.split_once('/')?;
        let line_range = if let Some((start, end)) = hash.split_once('-') {
            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
        } else {
            let row = hash.strip_prefix("L")?.parse().ok()?;
            row..=row
        };
        Some((repo_name, sha, query, file_path, line_range))
    });

    // Group the annotations by repo and sha.
    let mut evaluations_by_repo = BTreeMap::new();
    for (repo_name, sha, query, file_path, lines) in rows {
        let evaluation_project = evaluations_by_repo
            .entry((repo_name, sha))
            .or_insert_with(|| EvaluationProject {
                repo: repo_name.to_string(),
                sha: sha.to_string(),
                queries: Vec::new(),
            });

        let ix = evaluation_project
            .queries
            .iter()
            .position(|entry| entry.query == query)
            .unwrap_or_else(|| {
                evaluation_project.queries.push(EvaluationQuery {
                    query: query.to_string(),
                    expected_results: Vec::new(),
                });
                evaluation_project.queries.len() - 1
            });
        let results = &mut evaluation_project.queries[ix].expected_results;
        let result = EvaluationSearchResult {
            file: file_path.to_string(),
            lines,
        };
        if !results.contains(&result) {
            results.push(result);
        }
    }

    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
    let evaluations_path = dataset_dir.join("evaluations.json");
    fs::write(
        &evaluations_path,
        serde_json::to_vec_pretty(&evaluations).unwrap(),
    )
    .unwrap();

    eprintln!(
        "Fetched CodeSearchNet evaluations into {}",
        evaluations_path.display()
    );

    Ok(())
}

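/// Runs the evaluation over every project in `evaluations.json` (or just
/// `only_repo` if given). For each query it runs a semantic search against the
/// indexed repository and classifies each expected result as:
/// - "covered": a search result spans the entire expected line range,
/// - "overlapped": a search result touches the start or end of that range,
/// - "file": a search result at least lands in the expected file.
///
/// Per-query outcomes are printed as JSON lines on stdout; progress goes to
/// stderr.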
async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncAppContext,
) -> Result<()> {
    cx.update(|cx| {
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        cx.update_flags(false, vec![]);
    })
    .unwrap();

    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    let http_client = http_client::HttpClientWithProxy::new(None, None);
    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    "https://zed.dev",
                    None,
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx
        .new_model(|cx| UserStore::new(client.clone(), cx))
        .unwrap();
    let node_runtime = Arc::new(FakeNodeRuntime {});

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

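    // Aggregate counters across all projects and queries; per-query outcomes
    // are printed as JSON lines to stdout as they are computed.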
    let mut covered_result_count = 0;
    let mut overlapped_result_count = 0;
    let mut covered_file_count = 0;
    let mut total_result_count = 0;
    eprint!("Running evals.");

    for evaluation_project in evaluations {
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            covered_result_count,
            total_result_count,
            overlapped_result_count,
            total_result_count,
            covered_file_count,
            total_result_count,
            evaluation_project.repo
        );

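        // Each repository gets its own semantic database under EVAL_DB_PATH so
        // that indexes from different projects don't interfere with each other.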
        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
        let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider.clone(), cx)
            .await
            .unwrap();

        let repo_dir = repos_dir.join(&evaluation_project.repo);
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!(
                "Skipping {}: not checked out or marked to skip",
                evaluation_project.repo
            );
            continue;
        }

        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let (worktree, _) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(repo_dir, true, cx)
            })?
            .await?;

        worktree
            .update(cx, |worktree, _| {
                worktree.as_local().unwrap().scan_complete()
            })
            .unwrap()
            .await;

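        // Index the project and wait (up to two minutes) for the index to go
        // idle before issuing any searches against it.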
        let project_index = cx
            .update(|cx| semantic_index.create_project_index(project.clone(), cx))
            .unwrap();
        wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;

        for query in evaluation_project.queries {
            let results = cx
                .update(|cx| {
                    let project_index = project_index.read(cx);
                    project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx)
                })
                .unwrap()
                .await
                .unwrap();

            let results = SemanticDb::load_results(results, &fs.clone(), &cx)
                .await
                .unwrap();

            let mut project_covered_result_count = 0;
            let mut project_overlapped_result_count = 0;
            let mut project_covered_file_count = 0;
            let mut covered_result_indices = Vec::new();
            for expected_result in &query.expected_results {
                let mut file_matched = false;
                let mut range_overlapped = false;
                let mut range_covered = false;

                for (ix, result) in results.iter().enumerate() {
                    if result.path.as_ref() == Path::new(&expected_result.file) {
                        file_matched = true;
                        let start_matched =
                            result.row_range.contains(expected_result.lines.start());
                        let end_matched = result.row_range.contains(expected_result.lines.end());

                        if start_matched || end_matched {
                            range_overlapped = true;
                        }

                        if start_matched && end_matched {
                            range_covered = true;
                            covered_result_indices.push(ix);
                            break;
                        }
                    }
                }

                if range_covered {
                    project_covered_result_count += 1;
                }
                if range_overlapped {
                    project_overlapped_result_count += 1;
                }
                if file_matched {
                    project_covered_file_count += 1;
                }
            }
            let outcome_repo = evaluation_project.repo.clone();

            let query_results = EvaluationQueryOutcome {
                repo: outcome_repo,
                query: query.query,
                total_result_count: query.expected_results.len(),
                covered_result_count: project_covered_result_count,
                overlapped_result_count: project_overlapped_result_count,
                covered_file_count: project_covered_file_count,
                expected_results: query.expected_results,
                actual_results: results
                    .iter()
                    .map(|result| EvaluationSearchResult {
                        file: result.path.to_string_lossy().to_string(),
                        lines: result.row_range.clone(),
                    })
                    .collect(),
                covered_result_indices,
            };

            overlapped_result_count += query_results.overlapped_result_count;
            covered_result_count += query_results.covered_result_count;
            covered_file_count += query_results.covered_file_count;
            total_result_count += query_results.total_result_count;

            println!("{}", serde_json::to_string(&query_results).unwrap());
        }
    }

    eprint!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured.",
        covered_result_count,
        total_result_count,
        overlapped_result_count,
        total_result_count,
        covered_file_count,
        total_result_count,
    );

    Ok(())
}

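/// Waits until the project index reports `Status::Idle`, or until `timeout`
/// elapses (if one is given), by subscribing to the index's status events.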
async fn wait_for_indexing_complete(
    project_index: &Model<ProjectIndex>,
    cx: &mut AsyncAppContext,
    timeout: Option<Duration>,
) {
    let (tx, rx) = bounded(1);
    let subscription = cx.update(|cx| {
        cx.subscribe(project_index, move |_, event, _| {
            if let Status::Idle = event {
                let _ = tx.try_send(*event);
            }
        })
    });

    let result = match timeout {
        Some(timeout_duration) => {
            smol::future::or(
                async {
                    rx.recv().await.map_err(|_| ())?;
                    Ok(())
                },
                async {
                    Timer::after(timeout_duration).await;
                    Err(())
                },
            )
            .await
        }
        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
    };

    match result {
        Ok(_) => (),
        Err(_) => {
            if let Some(timeout) = timeout {
                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
            }
        }
    }

    drop(subscription);
}

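/// Clones every repository referenced by `evaluations.json` into
/// EVAL_REPOS_DIR, splitting the work across several background tasks.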
async fn fetch_eval_repos(
    executor: &BackgroundExecutor,
    http_client: &dyn HttpClient,
) -> Result<()> {
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    eprint!("Fetching evaluation repositories...");

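    // Split the repo list into roughly eight equal chunks and fetch each chunk
    // on its own scoped background task.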
    executor
        .scoped(move |scope| {
            let done_count = Arc::new(AtomicUsize::new(0));
            let len = evaluations.len();
            for chunk in evaluations.chunks((evaluations.len() / 8).max(1)) {
                let chunk = chunk.to_vec();
                let done_count = done_count.clone();
                scope.spawn(async move {
                    for EvaluationProject { repo, sha, .. } in chunk {
                        eprint!(
                            "\rFetching evaluation repositories ({}/{})...",
                            done_count.load(SeqCst),
                            len,
                        );

                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
                        done_count.fetch_add(1, SeqCst);
                    }
                });
            }
        })
        .await;

    Ok(())
}

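/// Ensures a single evaluation repository is checked out at `sha` under
/// `repos_dir`. Repositories that are no longer publicly accessible are marked
/// with a `.skip_eval` file so that later runs skip them.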
async fn fetch_eval_repo(
    repo: String,
    sha: String,
    repos_dir: &Path,
    http_client: &dyn HttpClient,
) {
    let Some((owner, repo_name)) = repo.split_once('/') else {
        return;
    };
    let repo_dir = repos_dir.join(owner).join(repo_name);
    fs::create_dir_all(&repo_dir).unwrap();
    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
    if skip_eval_path.exists() {
        return;
    }
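    // If the repository is already checked out at the expected SHA, there is
    // nothing to do.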
    if let Ok(head_content) = fs::read_to_string(repo_dir.join(".git").join("HEAD")) {
        if head_content.trim() == sha {
            return;
        }
    }
    let repo_response = http_client
        .send(
            http_client::Request::builder()
                .method(Method::HEAD)
                .uri(format!("https://github.com/{}", repo))
                .body(Default::default())
                .expect("failed to build HEAD request"),
        )
        .await
        .expect("failed to check github repo");
    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
        fs::write(&skip_eval_path, "").unwrap();
        eprintln!(
            "Repo {repo} is no longer public ({:?}). Skipping",
            repo_response.status()
        );
        return;
    }
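    // Initialize an empty git repository (if needed), then shallow-fetch and
    // check out just the evaluation SHA.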
    if !repo_dir.join(".git").exists() {
        let init_output = Command::new("git")
            .current_dir(&repo_dir)
            .args(&["init"])
            .output()
            .unwrap();
        if !init_output.status.success() {
            eprintln!(
                "Failed to initialize git repository for {}: {}",
                repo,
                String::from_utf8_lossy(&init_output.stderr)
            );
            return;
        }
    }
    let url = format!("https://github.com/{}.git", repo);
    Command::new("git")
        .current_dir(&repo_dir)
        .args(&["remote", "add", "-f", "origin", &url])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    let fetch_output = Command::new("git")
        .current_dir(&repo_dir)
        .args(&["fetch", "--depth", "1", "origin", &sha])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    if !fetch_output.status.success() {
        eprintln!(
            "Failed to fetch {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&fetch_output.stderr)
        );
        return;
    }
    let checkout_output = Command::new("git")
        .current_dir(&repo_dir)
        .args(&["checkout", &sha])
        .output()
        .unwrap();

    if !checkout_output.status.success() {
        eprintln!(
            "Failed to checkout {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&checkout_output.stderr)
        );
    }
}