use ::fs::{Fs, RealFs};
use anyhow::{Context as _, Result};
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::NodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use reqwest_client::ReqwestClient;
use semantic_index::{
    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};
36
// Filesystem layout and tuning knobs for the evaluation tool.
// `'static` is implied for string literals in `const` items, so the explicit
// lifetime (flagged by clippy's `redundant_static_lifetimes`) is dropped.

/// Where the CodeSearchNet annotations CSV and derived `evaluations.json` live.
const CODESEARCH_NET_DIR: &str = "target/datasets/code-search-net";
/// Root directory under which evaluation repositories are checked out.
const EVAL_REPOS_DIR: &str = "target/datasets/eval-repos";
/// Directory holding one semantic-index database per evaluated repository.
const EVAL_DB_PATH: &str = "target/eval_db";
/// Maximum number of search results requested per evaluation query.
const SEARCH_RESULT_LIMIT: usize = 8;
/// Marker file written into a repo dir to exclude it from future runs.
const SKIP_EVAL_PATH: &str = ".skip_eval";
42
// Top-level command-line interface for the evaluation binary.
// Plain `//` comments are used deliberately: `///` doc comments on clap
// derive types become help text and would change the CLI's runtime output.
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    // The subcommand to execute; see `Commands`.
    #[command(subcommand)]
    command: Commands,
}
49
// Available subcommands. `//` comments (not `///`) so clap's derived help
// output is unaffected.
#[derive(clap::Subcommand)]
enum Commands {
    // Download the CodeSearchNet annotations and clone all referenced repos.
    Fetch {},
    // Run the evaluation over the fetched repositories.
    Run {
        // Restrict the run to a single repository ("owner/name").
        #[arg(long)]
        repo: Option<String>,
    },
}
58
/// One repository in the evaluation dataset: a repo pinned to a specific
/// commit, plus the annotated search queries to evaluate against it.
/// Serialized as an element of `evaluations.json`.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    /// Repository identifier in "owner/name" form.
    repo: String,
    /// Commit SHA the annotations were made against.
    sha: String,
    /// The annotated search queries for this repo + sha.
    queries: Vec<EvaluationQuery>,
}
65
/// A single natural-language search query and the human-annotated locations
/// a good semantic search should surface for it.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    /// The search query text.
    query: String,
    /// Annotated ground-truth locations (deduplicated at fetch time).
    expected_results: Vec<EvaluationSearchResult>,
}
71
/// A location in the repository: a file plus an inclusive line range.
/// `PartialEq` is derived so duplicate expected results can be filtered out.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    /// Path of the file, relative to the repository root.
    file: String,
    /// Inclusive 1-based line range within the file.
    lines: RangeInclusive<u32>,
}
77
/// Aggregated outcome of evaluating one project.
/// NOTE(review): this type is never constructed anywhere in this file —
/// presumably it documents the shape consumed by external tooling that reads
/// the emitted JSON; confirm before removing.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    /// Repository identifier in "owner/name" form.
    repo: String,
    /// Commit SHA the evaluation ran against.
    sha: String,
    /// Per-query outcomes for this project.
    queries: Vec<EvaluationQueryOutcome>,
}
84
/// Outcome of running one query against an indexed repository; one of these
/// is printed as a JSON line to stdout per query.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    /// Repository identifier in "owner/name" form.
    repo: String,
    /// The query text that was searched.
    query: String,
    /// The annotated ground-truth locations for the query.
    expected_results: Vec<EvaluationSearchResult>,
    /// The locations actually returned by semantic search.
    actual_results: Vec<EvaluationSearchResult>,
    /// How many expected results had at least a file-level match.
    covered_file_count: usize,
    /// How many expected results overlapped an actual result on at least
    /// one endpoint of their line range.
    overlapped_result_count: usize,
    /// How many expected results were fully contained in an actual result.
    covered_result_count: usize,
    /// Total number of expected results for this query (the denominator
    /// for all three counts above).
    total_result_count: usize,
    /// Indices into `actual_results` that fully covered an expected result.
    covered_result_indices: Vec<usize>,
}
97
98fn main() -> Result<()> {
99 let cli = Cli::parse();
100 env_logger::init();
101
102 gpui::App::headless().run(move |cx| {
103 let executor = cx.background_executor().clone();
104 let client = Arc::new(ReqwestClient::user_agent("Zed LLM evals").unwrap());
105 cx.set_http_client(client.clone());
106 match cli.command {
107 Commands::Fetch {} => {
108 executor
109 .clone()
110 .spawn(async move {
111 if let Err(err) = fetch_evaluation_resources(client, &executor).await {
112 eprintln!("Error: {}", err);
113 exit(1);
114 }
115 exit(0);
116 })
117 .detach();
118 }
119 Commands::Run { repo } => {
120 cx.spawn(|mut cx| async move {
121 if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
122 eprintln!("Error: {}", err);
123 exit(1);
124 }
125 exit(0);
126 })
127 .detach();
128 }
129 }
130 });
131
132 Ok(())
133}
134
135async fn fetch_evaluation_resources(
136 http_client: Arc<dyn HttpClient>,
137 executor: &BackgroundExecutor,
138) -> Result<()> {
139 fetch_code_search_net_resources(&*http_client).await?;
140 fetch_eval_repos(executor, &*http_client).await?;
141 Ok(())
142}
143
144async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
145 eprintln!("Fetching CodeSearchNet evaluations...");
146
147 let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
148
149 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
150 fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
151
152 // Fetch the annotations CSV, which contains the human-annotated search relevances
153 let annotations_path = dataset_dir.join("annotations.csv");
154 let annotations_csv_content = if annotations_path.exists() {
155 fs::read_to_string(&annotations_path).expect("failed to read annotations")
156 } else {
157 let response = http_client
158 .get(annotations_url, Default::default(), true)
159 .await
160 .expect("failed to fetch annotations csv");
161 let mut body = String::new();
162 response
163 .into_body()
164 .read_to_string(&mut body)
165 .await
166 .expect("failed to read annotations.csv response");
167 fs::write(annotations_path, &body).expect("failed to write annotations.csv");
168 body
169 };
170
171 // Parse the annotations CSV. Skip over queries with zero relevance.
172 let rows = annotations_csv_content.lines().filter_map(|line| {
173 let mut values = line.split(',');
174 let _language = values.next()?;
175 let query = values.next()?;
176 let github_url = values.next()?;
177 let score = values.next()?;
178
179 if score == "0" {
180 return None;
181 }
182
183 let url_path = github_url.strip_prefix("https://github.com/")?;
184 let (url_path, hash) = url_path.split_once('#')?;
185 let (repo_name, url_path) = url_path.split_once("/blob/")?;
186 let (sha, file_path) = url_path.split_once('/')?;
187 let line_range = if let Some((start, end)) = hash.split_once('-') {
188 start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
189 } else {
190 let row = hash.strip_prefix("L")?.parse().ok()?;
191 row..=row
192 };
193 Some((repo_name, sha, query, file_path, line_range))
194 });
195
196 // Group the annotations by repo and sha.
197 let mut evaluations_by_repo = BTreeMap::new();
198 for (repo_name, sha, query, file_path, lines) in rows {
199 let evaluation_project = evaluations_by_repo
200 .entry((repo_name, sha))
201 .or_insert_with(|| EvaluationProject {
202 repo: repo_name.to_string(),
203 sha: sha.to_string(),
204 queries: Vec::new(),
205 });
206
207 let ix = evaluation_project
208 .queries
209 .iter()
210 .position(|entry| entry.query == query)
211 .unwrap_or_else(|| {
212 evaluation_project.queries.push(EvaluationQuery {
213 query: query.to_string(),
214 expected_results: Vec::new(),
215 });
216 evaluation_project.queries.len() - 1
217 });
218 let results = &mut evaluation_project.queries[ix].expected_results;
219 let result = EvaluationSearchResult {
220 file: file_path.to_string(),
221 lines,
222 };
223 if !results.contains(&result) {
224 results.push(result);
225 }
226 }
227
228 let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
229 let evaluations_path = dataset_dir.join("evaluations.json");
230 fs::write(
231 &evaluations_path,
232 serde_json::to_vec_pretty(&evaluations).unwrap(),
233 )
234 .unwrap();
235
236 eprintln!(
237 "Fetched CodeSearchNet evaluations into {}",
238 evaluations_path.display()
239 );
240
241 Ok(())
242}
243
/// Running totals accumulated across every query of every evaluated project.
/// All three "covered/overlapped/files" counters share `total_results` as
/// their denominator, since each expected result is scored on all three axes.
#[derive(Default, Debug)]
struct Counts {
    // Expected results fully contained within some actual result.
    covered_results: usize,
    // Expected results overlapping an actual result on at least one endpoint.
    overlapped_results: usize,
    // Expected results whose file appeared in the actual results at all.
    covered_files: usize,
    // Total number of expected results processed.
    total_results: usize,
}
251
252async fn run_evaluation(
253 only_repo: Option<String>,
254 executor: &BackgroundExecutor,
255 cx: &mut AsyncAppContext,
256) -> Result<()> {
257 let mut http_client = None;
258 cx.update(|cx| {
259 let mut store = SettingsStore::new(cx);
260 store
261 .set_default_settings(settings::default_settings().as_ref(), cx)
262 .unwrap();
263 cx.set_global(store);
264 client::init_settings(cx);
265 language::init(cx);
266 Project::init_settings(cx);
267 http_client = Some(cx.http_client());
268 cx.update_flags(false, vec![]);
269 })
270 .unwrap();
271 let http_client = http_client.unwrap();
272 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
273 let evaluations_path = dataset_dir.join("evaluations.json");
274 let repos_dir = Path::new(EVAL_REPOS_DIR);
275 let db_path = Path::new(EVAL_DB_PATH);
276 let api_key = std::env::var("OPENAI_API_KEY").unwrap();
277 let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
278 let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
279 let clock = Arc::new(RealSystemClock);
280 let client = cx
281 .update(|cx| {
282 Client::new(
283 clock,
284 Arc::new(http_client::HttpClientWithUrl::new(
285 http_client.clone(),
286 "https://zed.dev",
287 None,
288 )),
289 cx,
290 )
291 })
292 .unwrap();
293 let user_store = cx
294 .new_model(|cx| UserStore::new(client.clone(), cx))
295 .unwrap();
296 let node_runtime = NodeRuntime::unavailable();
297
298 let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
299 let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
300
301 let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
302 http_client.clone(),
303 OpenAiEmbeddingModel::TextEmbedding3Small,
304 open_ai::OPEN_AI_API_URL.to_string(),
305 api_key,
306 ));
307
308 let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
309 cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
310 .unwrap();
311
312 let mut counts = Counts::default();
313 eprint!("Running evals.");
314
315 let mut failures = Vec::new();
316
317 for evaluation_project in evaluations {
318 if only_repo
319 .as_ref()
320 .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
321 {
322 continue;
323 }
324
325 eprint!("\r\x1B[2K");
326 eprint!(
327 "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
328 counts.covered_results,
329 counts.total_results,
330 counts.overlapped_results,
331 counts.total_results,
332 counts.covered_files,
333 counts.total_results,
334 evaluation_project.repo
335 );
336
337 let repo_dir = repos_dir.join(&evaluation_project.repo);
338 if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
339 eprintln!("Skipping {}: directory not found", evaluation_project.repo);
340 continue;
341 }
342
343 let repo_db_path =
344 db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
345
346 let project = cx
347 .update(|cx| {
348 Project::local(
349 client.clone(),
350 node_runtime.clone(),
351 user_store.clone(),
352 language_registry.clone(),
353 fs.clone(),
354 None,
355 cx,
356 )
357 })
358 .unwrap();
359
360 let repo = evaluation_project.repo.clone();
361 if let Err(err) = run_eval_project(
362 evaluation_project,
363 &user_store,
364 repo_db_path,
365 &repo_dir,
366 &mut counts,
367 project,
368 embedding_provider.clone(),
369 fs.clone(),
370 cx,
371 )
372 .await
373 {
374 eprintln!("{repo} eval failed with error: {:?}", err);
375
376 failures.push((repo, err));
377 }
378 }
379
380 eprintln!(
381 "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
382 counts.covered_results,
383 counts.total_results,
384 counts.overlapped_results,
385 counts.total_results,
386 counts.covered_files,
387 counts.total_results,
388 failures.len(),
389 );
390
391 if failures.is_empty() {
392 Ok(())
393 } else {
394 eprintln!("Failures:\n");
395
396 for (index, (repo, failure)) in failures.iter().enumerate() {
397 eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
398 }
399
400 Err(anyhow::anyhow!("Some evals failed."))
401 }
402}
403
404#[allow(clippy::too_many_arguments)]
405async fn run_eval_project(
406 evaluation_project: EvaluationProject,
407 user_store: &Model<UserStore>,
408 repo_db_path: PathBuf,
409 repo_dir: &Path,
410 counts: &mut Counts,
411 project: Model<Project>,
412 embedding_provider: Arc<dyn EmbeddingProvider>,
413 fs: Arc<dyn Fs>,
414 cx: &mut AsyncAppContext,
415) -> Result<(), anyhow::Error> {
416 let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;
417
418 let (worktree, _) = project
419 .update(cx, |project, cx| {
420 project.find_or_create_worktree(repo_dir, true, cx)
421 })?
422 .await?;
423
424 worktree
425 .update(cx, |worktree, _| {
426 worktree.as_local().unwrap().scan_complete()
427 })?
428 .await;
429
430 let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
431 wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
432
433 for query in evaluation_project.queries {
434 let results = {
435 // Retry search up to 3 times in case of timeout, network failure, etc.
436 let mut retries_remaining = 3;
437 let mut result;
438
439 loop {
440 match cx.update(|cx| {
441 let project_index = project_index.read(cx);
442 project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
443 }) {
444 Ok(task) => match task.await {
445 Ok(answer) => {
446 result = Ok(answer);
447 break;
448 }
449 Err(err) => {
450 result = Err(err);
451 }
452 },
453 Err(err) => {
454 result = Err(err);
455 }
456 }
457
458 if retries_remaining > 0 {
459 eprintln!(
460 "Retrying search after it failed on query {:?} with {:?}",
461 query, result
462 );
463 retries_remaining -= 1;
464 } else {
465 eprintln!(
466 "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
467 query, result
468 );
469 break;
470 }
471 }
472
473 SemanticDb::load_results(result?, &fs.clone(), &cx).await?
474 };
475
476 let mut project_covered_result_count = 0;
477 let mut project_overlapped_result_count = 0;
478 let mut project_covered_file_count = 0;
479 let mut covered_result_indices = Vec::new();
480 for expected_result in &query.expected_results {
481 let mut file_matched = false;
482 let mut range_overlapped = false;
483 let mut range_covered = false;
484
485 for (ix, result) in results.iter().enumerate() {
486 if result.path.as_ref() == Path::new(&expected_result.file) {
487 file_matched = true;
488 let start_matched = result.row_range.contains(&expected_result.lines.start());
489 let end_matched = result.row_range.contains(&expected_result.lines.end());
490
491 if start_matched || end_matched {
492 range_overlapped = true;
493 }
494
495 if start_matched && end_matched {
496 range_covered = true;
497 covered_result_indices.push(ix);
498 break;
499 }
500 }
501 }
502
503 if range_covered {
504 project_covered_result_count += 1
505 };
506 if range_overlapped {
507 project_overlapped_result_count += 1
508 };
509 if file_matched {
510 project_covered_file_count += 1
511 };
512 }
513 let outcome_repo = evaluation_project.repo.clone();
514
515 let query_results = EvaluationQueryOutcome {
516 repo: outcome_repo,
517 query: query.query,
518 total_result_count: query.expected_results.len(),
519 covered_result_count: project_covered_result_count,
520 overlapped_result_count: project_overlapped_result_count,
521 covered_file_count: project_covered_file_count,
522 expected_results: query.expected_results,
523 actual_results: results
524 .iter()
525 .map(|result| EvaluationSearchResult {
526 file: result.path.to_string_lossy().to_string(),
527 lines: result.row_range.clone(),
528 })
529 .collect(),
530 covered_result_indices,
531 };
532
533 counts.overlapped_results += query_results.overlapped_result_count;
534 counts.covered_results += query_results.covered_result_count;
535 counts.covered_files += query_results.covered_file_count;
536 counts.total_results += query_results.total_result_count;
537
538 println!("{}", serde_json::to_string(&query_results)?);
539 }
540
541 user_store.update(cx, |_, _| {
542 drop(semantic_index);
543 drop(project);
544 drop(worktree);
545 drop(project_index);
546 })
547}
548
549async fn wait_for_indexing_complete(
550 project_index: &Model<ProjectIndex>,
551 cx: &mut AsyncAppContext,
552 timeout: Option<Duration>,
553) {
554 let (tx, rx) = bounded(1);
555 let subscription = cx.update(|cx| {
556 cx.subscribe(project_index, move |_, event, _| {
557 if let Status::Idle = event {
558 let _ = tx.try_send(*event);
559 }
560 })
561 });
562
563 let result = match timeout {
564 Some(timeout_duration) => {
565 smol::future::or(
566 async {
567 rx.recv().await.map_err(|_| ())?;
568 Ok(())
569 },
570 async {
571 Timer::after(timeout_duration).await;
572 Err(())
573 },
574 )
575 .await
576 }
577 None => rx.recv().await.map(|_| ()).map_err(|_| ()),
578 };
579
580 match result {
581 Ok(_) => (),
582 Err(_) => {
583 if let Some(timeout) = timeout {
584 eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
585 }
586 }
587 }
588
589 drop(subscription);
590}
591
592async fn fetch_eval_repos(
593 executor: &BackgroundExecutor,
594 http_client: &dyn HttpClient,
595) -> Result<()> {
596 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
597 let evaluations_path = dataset_dir.join("evaluations.json");
598 let repos_dir = Path::new(EVAL_REPOS_DIR);
599
600 let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
601 let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
602
603 eprintln!("Fetching evaluation repositories...");
604
605 executor
606 .scoped(move |scope| {
607 let done_count = Arc::new(AtomicUsize::new(0));
608 let len = evaluations.len();
609 for chunk in evaluations.chunks(evaluations.len() / 8) {
610 let chunk = chunk.to_vec();
611 let done_count = done_count.clone();
612 scope.spawn(async move {
613 for EvaluationProject { repo, sha, .. } in chunk {
614 eprint!(
615 "\rFetching evaluation repositories ({}/{})...",
616 done_count.load(SeqCst),
617 len,
618 );
619
620 fetch_eval_repo(repo, sha, repos_dir, http_client).await;
621 done_count.fetch_add(1, SeqCst);
622 }
623 });
624 }
625 })
626 .await;
627
628 Ok(())
629}
630
631async fn fetch_eval_repo(
632 repo: String,
633 sha: String,
634 repos_dir: &Path,
635 http_client: &dyn HttpClient,
636) {
637 let Some((owner, repo_name)) = repo.split_once('/') else {
638 return;
639 };
640 let repo_dir = repos_dir.join(owner).join(repo_name);
641 fs::create_dir_all(&repo_dir).unwrap();
642 let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
643 if skip_eval_path.exists() {
644 return;
645 }
646 if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
647 if head_content.trim() == sha {
648 return;
649 }
650 }
651 let repo_response = http_client
652 .send(
653 http_client::Request::builder()
654 .method(Method::HEAD)
655 .uri(format!("https://github.com/{}", repo))
656 .body(Default::default())
657 .expect(""),
658 )
659 .await
660 .expect("failed to check github repo");
661 if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
662 fs::write(&skip_eval_path, "").unwrap();
663 eprintln!(
664 "Repo {repo} is no longer public ({:?}). Skipping",
665 repo_response.status()
666 );
667 return;
668 }
669 if !repo_dir.join(".git").exists() {
670 let init_output = util::command::new_std_command("git")
671 .current_dir(&repo_dir)
672 .args(&["init"])
673 .output()
674 .unwrap();
675 if !init_output.status.success() {
676 eprintln!(
677 "Failed to initialize git repository for {}: {}",
678 repo,
679 String::from_utf8_lossy(&init_output.stderr)
680 );
681 return;
682 }
683 }
684 let url = format!("https://github.com/{}.git", repo);
685 util::command::new_std_command("git")
686 .current_dir(&repo_dir)
687 .args(&["remote", "add", "-f", "origin", &url])
688 .stdin(Stdio::null())
689 .output()
690 .unwrap();
691 let fetch_output = util::command::new_std_command("git")
692 .current_dir(&repo_dir)
693 .args(&["fetch", "--depth", "1", "origin", &sha])
694 .stdin(Stdio::null())
695 .output()
696 .unwrap();
697 if !fetch_output.status.success() {
698 eprintln!(
699 "Failed to fetch {} for {}: {}",
700 sha,
701 repo,
702 String::from_utf8_lossy(&fetch_output.stderr)
703 );
704 return;
705 }
706 let checkout_output = util::command::new_std_command("git")
707 .current_dir(&repo_dir)
708 .args(&["checkout", &sha])
709 .output()
710 .unwrap();
711
712 if !checkout_output.status.success() {
713 eprintln!(
714 "Failed to checkout {} for {}: {}",
715 sha,
716 repo,
717 String::from_utf8_lossy(&checkout_output.stderr)
718 );
719 }
720}