use ::fs::{Fs, RealFs};
use anyhow::Result;
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::NodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use semantic_index::{
    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Command, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};

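// Locations (relative to the working directory) for the downloaded dataset, the cloned
// evaluation repositories, and the semantic index database, plus the number of search
// results to request and the marker file used to skip repositories.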
const CODESEARCH_NET_DIR: &str = "target/datasets/code-search-net";
const EVAL_REPOS_DIR: &str = "target/datasets/eval-repos";
const EVAL_DB_PATH: &str = "target/eval_db";
const SEARCH_RESULT_LIMIT: usize = 8;
const SKIP_EVAL_PATH: &str = ".skip_eval";

#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

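/// `fetch` downloads the CodeSearchNet annotations and the referenced repositories;
/// `run` indexes each repository and evaluates search quality, optionally for a single repo.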
#[derive(clap::Subcommand)]
enum Commands {
    Fetch {},
    Run {
        #[arg(long)]
        repo: Option<String>,
    },
}

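/// A repository at a specific commit, together with the annotated queries to evaluate against it.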
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQuery>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    file: String,
    lines: RangeInclusive<u32>,
}

#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}

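/// Per-query results: an expected result is "covered" when both of its line endpoints fall
/// inside a single retrieved excerpt, "overlapped" when at least one endpoint does, and a
/// file is counted as covered when it appears anywhere in the search results.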
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    repo: String,
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
    actual_results: Vec<EvaluationSearchResult>,
    covered_file_count: usize,
    overlapped_result_count: usize,
    covered_result_count: usize,
    total_result_count: usize,
    covered_result_indices: Vec<usize>,
}

fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::App::headless().run(move |cx| {
        let executor = cx.background_executor().clone();
        let client = isahc_http_client::IsahcHttpClient::new(None, None);
        cx.set_http_client(client.clone());
        match cli.command {
            Commands::Fetch {} => {
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                cx.spawn(|mut cx| async move {
                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}

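/// Downloads the CodeSearchNet annotations and then shallow-fetches every repository they reference.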
async fn fetch_evaluation_resources(
    http_client: Arc<dyn HttpClient>,
    executor: &BackgroundExecutor,
) -> Result<()> {
    fetch_code_search_net_resources(&*http_client).await?;
    fetch_eval_repos(executor, &*http_client).await?;
    Ok(())
}

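/// Fetches the human-annotated CodeSearchNet relevance CSV (caching it on disk), keeps the rows
/// with a non-zero relevance score, groups them by repository and commit SHA, and writes the
/// result to `evaluations.json`.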
async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
    eprintln!("Fetching CodeSearchNet evaluations...");

    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";

    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");

    // Fetch the annotations CSV, which contains the human-annotated search relevances.
    let annotations_path = dataset_dir.join("annotations.csv");
    let annotations_csv_content = if annotations_path.exists() {
        fs::read_to_string(&annotations_path).expect("failed to read annotations")
    } else {
        let response = http_client
            .get(annotations_url, Default::default(), true)
            .await
            .expect("failed to fetch annotations csv");
        let mut body = String::new();
        response
            .into_body()
            .read_to_string(&mut body)
            .await
            .expect("failed to read annotations.csv response");
        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
        body
    };

    // Parse the annotations CSV, skipping queries with zero relevance.
    let rows = annotations_csv_content.lines().filter_map(|line| {
        let mut values = line.split(',');
        let _language = values.next()?;
        let query = values.next()?;
        let github_url = values.next()?;
        let score = values.next()?;

        if score == "0" {
            return None;
        }

        // Example URL: https://github.com/owner/repo/blob/<sha>/path/to/file.py#L10-L20
        let url_path = github_url.strip_prefix("https://github.com/")?;
        let (url_path, hash) = url_path.split_once('#')?;
        let (repo_name, url_path) = url_path.split_once("/blob/")?;
        let (sha, file_path) = url_path.split_once('/')?;
        let line_range = if let Some((start, end)) = hash.split_once('-') {
            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
        } else {
            let row = hash.strip_prefix("L")?.parse().ok()?;
            row..=row
        };
        Some((repo_name, sha, query, file_path, line_range))
    });

    // Group the annotations by repo and sha.
    let mut evaluations_by_repo = BTreeMap::new();
    for (repo_name, sha, query, file_path, lines) in rows {
        let evaluation_project = evaluations_by_repo
            .entry((repo_name, sha))
            .or_insert_with(|| EvaluationProject {
                repo: repo_name.to_string(),
                sha: sha.to_string(),
                queries: Vec::new(),
            });

        let ix = evaluation_project
            .queries
            .iter()
            .position(|entry| entry.query == query)
            .unwrap_or_else(|| {
                evaluation_project.queries.push(EvaluationQuery {
                    query: query.to_string(),
                    expected_results: Vec::new(),
                });
                evaluation_project.queries.len() - 1
            });
        let results = &mut evaluation_project.queries[ix].expected_results;
        let result = EvaluationSearchResult {
            file: file_path.to_string(),
            lines,
        };
        if !results.contains(&result) {
            results.push(result);
        }
    }

    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
    let evaluations_path = dataset_dir.join("evaluations.json");
    fs::write(
        &evaluations_path,
        serde_json::to_vec_pretty(&evaluations).unwrap(),
    )
    .unwrap();

    eprintln!(
        "Fetched CodeSearchNet evaluations into {}",
        evaluations_path.display()
    );

    Ok(())
}

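/// Running totals accumulated across every query of every evaluated project.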
#[derive(Default, Debug)]
struct Counts {
    covered_results: usize,
    overlapped_results: usize,
    covered_files: usize,
    total_results: usize,
}

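/// Indexes each fetched repository with the OpenAI embedding provider and runs its annotated
/// queries against the semantic index, printing one JSON outcome per query to stdout and a
/// progress summary to stderr. Requires `OPENAI_API_KEY` to be set.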
async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncAppContext,
) -> Result<()> {
    let mut http_client = None;
    cx.update(|cx| {
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        http_client = Some(cx.http_client());
        cx.update_flags(false, vec![]);
    })
    .unwrap();
    let http_client = http_client.unwrap();
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    http_client.clone(),
                    "https://zed.dev",
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx
        .new_model(|cx| UserStore::new(client.clone(), cx))
        .unwrap();
    let node_runtime = NodeRuntime::unavailable();

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

    let mut counts = Counts::default();
    eprint!("Running evals.");

    let mut failures = Vec::new();

    for evaluation_project in evaluations {
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            counts.covered_results,
            counts.total_results,
            counts.overlapped_results,
            counts.total_results,
            counts.covered_files,
            counts.total_results,
            evaluation_project.repo
        );

        let repo_dir = repos_dir.join(&evaluation_project.repo);
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!(
                "Skipping {}: repository directory is missing or marked to be skipped",
                evaluation_project.repo
            );
            continue;
        }

        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));

        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let repo = evaluation_project.repo.clone();
        if let Err(err) = run_eval_project(
            evaluation_project,
            &user_store,
            repo_db_path,
            &repo_dir,
            &mut counts,
            project,
            embedding_provider.clone(),
            fs.clone(),
            cx,
        )
        .await
        {
            eprintln!("{repo} eval failed with error: {:?}", err);

            failures.push((repo, err));
        }
    }

    eprintln!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
        counts.covered_results,
        counts.total_results,
        counts.overlapped_results,
        counts.total_results,
        counts.covered_files,
        counts.total_results,
        failures.len(),
    );

    if failures.is_empty() {
        Ok(())
    } else {
        eprintln!("Failures:\n");

        for (index, (repo, failure)) in failures.iter().enumerate() {
            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
        }

        Err(anyhow::anyhow!("Some evals failed."))
    }
}

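/// Indexes a single repository, waits for indexing to finish (with a timeout), runs each
/// annotated query against the project index (retrying failed searches), scores the results
/// against the expected excerpts, and prints the outcome for each query as a JSON line.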
#[allow(clippy::too_many_arguments)]
async fn run_eval_project(
    evaluation_project: EvaluationProject,
    user_store: &Model<UserStore>,
    repo_db_path: PathBuf,
    repo_dir: &Path,
    counts: &mut Counts,
    project: Model<Project>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    fs: Arc<dyn Fs>,
    cx: &mut AsyncAppContext,
) -> Result<(), anyhow::Error> {
    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;

    let (worktree, _) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(repo_dir, true, cx)
        })?
        .await?;

    worktree
        .update(cx, |worktree, _| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;

    for query in evaluation_project.queries {
        let results = {
            // Retry the search up to 3 times in case of timeout, network failure, etc.
            let mut retries_remaining = 3;
            let mut result;

            loop {
                match cx.update(|cx| {
                    let project_index = project_index.read(cx);
                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
                }) {
                    Ok(task) => match task.await {
                        Ok(answer) => {
                            result = Ok(answer);
                            break;
                        }
                        Err(err) => {
                            result = Err(err);
                        }
                    },
                    Err(err) => {
                        result = Err(err);
                    }
                }

                if retries_remaining > 0 {
                    eprintln!(
                        "Retrying search after it failed on query {:?} with {:?}",
                        query, result
                    );
                    retries_remaining -= 1;
                } else {
                    eprintln!(
                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
                        query, result
                    );
                    break;
                }
            }

            SemanticDb::load_results(result?, &fs, cx).await?
        };

        let mut project_covered_result_count = 0;
        let mut project_overlapped_result_count = 0;
        let mut project_covered_file_count = 0;
        let mut covered_result_indices = Vec::new();
        for expected_result in &query.expected_results {
            let mut file_matched = false;
            let mut range_overlapped = false;
            let mut range_covered = false;

            for (ix, result) in results.iter().enumerate() {
                if result.path.as_ref() == Path::new(&expected_result.file) {
                    file_matched = true;
                    let start_matched = result.row_range.contains(expected_result.lines.start());
                    let end_matched = result.row_range.contains(expected_result.lines.end());

                    if start_matched || end_matched {
                        range_overlapped = true;
                    }

                    if start_matched && end_matched {
                        range_covered = true;
                        covered_result_indices.push(ix);
                        break;
                    }
                }
            }

            if range_covered {
                project_covered_result_count += 1;
            }
            if range_overlapped {
                project_overlapped_result_count += 1;
            }
            if file_matched {
                project_covered_file_count += 1;
            }
        }
        let outcome_repo = evaluation_project.repo.clone();

        let query_results = EvaluationQueryOutcome {
            repo: outcome_repo,
            query: query.query,
            total_result_count: query.expected_results.len(),
            covered_result_count: project_covered_result_count,
            overlapped_result_count: project_overlapped_result_count,
            covered_file_count: project_covered_file_count,
            expected_results: query.expected_results,
            actual_results: results
                .iter()
                .map(|result| EvaluationSearchResult {
                    file: result.path.to_string_lossy().to_string(),
                    lines: result.row_range.clone(),
                })
                .collect(),
            covered_result_indices,
        };

        counts.overlapped_results += query_results.overlapped_result_count;
        counts.covered_results += query_results.covered_result_count;
        counts.covered_files += query_results.covered_file_count;
        counts.total_results += query_results.total_result_count;

        println!("{}", serde_json::to_string(&query_results)?);
    }

    // Release the per-repository state within an app update before moving on to the next project.
    user_store.update(cx, |_, _| {
        drop(semantic_index);
        drop(project);
        drop(worktree);
        drop(project_index);
    })
}

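/// Subscribes to the project index and returns once it reports `Status::Idle`, or once the
/// optional timeout elapses (in which case a warning is printed and indexing may still be
/// incomplete).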
async fn wait_for_indexing_complete(
    project_index: &Model<ProjectIndex>,
    cx: &mut AsyncAppContext,
    timeout: Option<Duration>,
) {
    let (tx, rx) = bounded(1);
    let subscription = cx.update(|cx| {
        cx.subscribe(project_index, move |_, event, _| {
            if let Status::Idle = event {
                let _ = tx.try_send(*event);
            }
        })
    });

    let result = match timeout {
        Some(timeout_duration) => {
            // Race the status channel against a timer.
            smol::future::or(
                async {
                    rx.recv().await.map_err(|_| ())?;
                    Ok(())
                },
                async {
                    Timer::after(timeout_duration).await;
                    Err(())
                },
            )
            .await
        }
        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
    };

    match result {
        Ok(_) => (),
        Err(_) => {
            if let Some(timeout) = timeout {
                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
            }
        }
    }

    drop(subscription);
}

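/// Fetches every evaluation repository listed in `evaluations.json`, splitting the work across
/// roughly eight concurrent tasks on the background executor.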
async fn fetch_eval_repos(
    executor: &BackgroundExecutor,
    http_client: &dyn HttpClient,
) -> Result<()> {
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    eprintln!("Fetching evaluation repositories...");

    executor
        .scoped(move |scope| {
            let done_count = Arc::new(AtomicUsize::new(0));
            let len = evaluations.len();
            // Split the repositories into roughly eight chunks; `.max(1)` avoids a zero-sized
            // chunk (and a panic) when there are fewer than eight evaluations.
            for chunk in evaluations.chunks((evaluations.len() / 8).max(1)) {
                let chunk = chunk.to_vec();
                let done_count = done_count.clone();
                scope.spawn(async move {
                    for EvaluationProject { repo, sha, .. } in chunk {
                        eprint!(
                            "\rFetching evaluation repositories ({}/{})...",
                            done_count.load(SeqCst),
                            len,
                        );

                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
                        done_count.fetch_add(1, SeqCst);
                    }
                });
            }
        })
        .await;

    Ok(())
}

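/// Fetches a single repository at the given SHA into `repos_dir` using a shallow `git fetch`.
/// Repositories that are no longer publicly accessible are marked with a `.skip_eval` file so
/// later runs skip them.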
async fn fetch_eval_repo(
    repo: String,
    sha: String,
    repos_dir: &Path,
    http_client: &dyn HttpClient,
) {
    let Some((owner, repo_name)) = repo.split_once('/') else {
        return;
    };
    let repo_dir = repos_dir.join(owner).join(repo_name);
    fs::create_dir_all(&repo_dir).unwrap();
    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
    if skip_eval_path.exists() {
        return;
    }
    // If the repository is already checked out at the requested SHA, there is nothing to do.
    if let Ok(head_content) = fs::read_to_string(repo_dir.join(".git").join("HEAD")) {
        if head_content.trim() == sha {
            return;
        }
    }
    // Check that the repository is still publicly accessible before trying to fetch it.
    let repo_response = http_client
        .send(
            http_client::Request::builder()
                .method(Method::HEAD)
                .uri(format!("https://github.com/{}", repo))
                .body(Default::default())
                .expect("failed to build request"),
        )
        .await
        .expect("failed to check github repo");
    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
        fs::write(&skip_eval_path, "").unwrap();
        eprintln!(
            "Repo {repo} is no longer public ({:?}). Skipping",
            repo_response.status()
        );
        return;
    }
    if !repo_dir.join(".git").exists() {
        let init_output = Command::new("git")
            .current_dir(&repo_dir)
            .args(&["init"])
            .output()
            .unwrap();
        if !init_output.status.success() {
            eprintln!(
                "Failed to initialize git repository for {}: {}",
                repo,
                String::from_utf8_lossy(&init_output.stderr)
            );
            return;
        }
    }
    let url = format!("https://github.com/{}.git", repo);
    Command::new("git")
        .current_dir(&repo_dir)
        .args(&["remote", "add", "-f", "origin", &url])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    let fetch_output = Command::new("git")
        .current_dir(&repo_dir)
        .args(&["fetch", "--depth", "1", "origin", &sha])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    if !fetch_output.status.success() {
        eprintln!(
            "Failed to fetch {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&fetch_output.stderr)
        );
        return;
    }
    let checkout_output = Command::new("git")
        .current_dir(&repo_dir)
        .args(&["checkout", &sha])
        .output()
        .unwrap();

    if !checkout_output.status.success() {
        eprintln!(
            "Failed to checkout {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&checkout_output.stderr)
        );
    }
}