use ::fs::{Fs, RealFs};
use anyhow::{Context as _, Result};
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use dap::DapRegistry;
use feature_flags::FeatureFlagAppExt as _;
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Entity};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::NodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use reqwest_client::ReqwestClient;
use semantic_index::{
    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::Timer;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{Stdio, exit},
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
36
// NOTE: `&'static` is redundant on `const` string slices — the lifetime is
// implied — so plain `&str` is used below (clippy: redundant_static_lifetimes).

/// Directory where the CodeSearchNet annotations CSV and the derived
/// `evaluations.json` are cached.
const CODESEARCH_NET_DIR: &str = "target/datasets/code-search-net";
/// Directory into which evaluation repositories are shallow-cloned.
const EVAL_REPOS_DIR: &str = "target/datasets/eval-repos";
/// Root directory for the per-repo semantic-index databases used during evals.
const EVAL_DB_PATH: &str = "target/eval_db";
/// Maximum number of semantic-search results requested per query.
const SEARCH_RESULT_LIMIT: usize = 8;
/// Marker file written into a repo directory to exclude it from future evals.
const SKIP_EVAL_PATH: &str = ".skip_eval";
42
// Top-level command-line interface for the eval binary.
// NOTE: plain `//` comments are used here instead of `///` doc comments,
// because clap surfaces doc comments as `--help` text and that would change
// the program's CLI output.
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    // The subcommand to execute (`fetch` or `run`).
    #[command(subcommand)]
    command: Commands,
}
49
// Subcommands of the eval binary. As with `Cli`, `//` comments are used so
// clap's `--help` output is not altered.
#[derive(clap::Subcommand)]
enum Commands {
    // Download the CodeSearchNet annotations and clone the eval repositories.
    Fetch {},
    // Run the evaluation against the previously fetched dataset.
    Run {
        // When set, only the repository with this `owner/name` is evaluated.
        #[arg(long)]
        repo: Option<String>,
    },
}
58
/// One repository pinned at a specific commit, together with every annotated
/// query that references it. Serialized to/from `evaluations.json`.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    /// Repository identifier in `owner/name` form.
    repo: String,
    /// Commit SHA the annotations refer to.
    sha: String,
    /// The search queries annotated against this repo/sha.
    queries: Vec<EvaluationQuery>,
}
65
/// A single natural-language search query and the results a human annotator
/// marked as relevant for it.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    /// The search query text.
    query: String,
    /// The human-annotated relevant results for this query.
    expected_results: Vec<EvaluationSearchResult>,
}
71
/// A file location identifying one search result: a path plus an inclusive
/// line range (parsed from GitHub `#L<start>-L<end>` URL anchors).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    /// Path of the file within the repository.
    file: String,
    /// Inclusive range of line numbers within the file.
    lines: RangeInclusive<u32>,
}
77
/// Per-project grouping of query outcomes.
/// NOTE(review): this type is not constructed anywhere in this file —
/// presumably consumed by external tooling that reads the emitted JSON, or
/// kept for schema compatibility; confirm before removing.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    /// Repository identifier in `owner/name` form.
    repo: String,
    /// Commit SHA that was evaluated.
    sha: String,
    /// Outcomes for each query run against this project.
    queries: Vec<EvaluationQueryOutcome>,
}
84
/// The measured outcome of running one query against one indexed repository.
/// One of these is serialized as a JSON line to stdout per evaluated query.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    /// Repository identifier in `owner/name` form.
    repo: String,
    /// The query that was executed.
    query: String,
    /// The human-annotated expected results.
    expected_results: Vec<EvaluationSearchResult>,
    /// The results actually returned by semantic search.
    actual_results: Vec<EvaluationSearchResult>,
    /// How many expected results had their file present among the actual results.
    covered_file_count: usize,
    /// How many expected results had a line range overlapping an actual result.
    overlapped_result_count: usize,
    /// How many expected results had their line range fully contained in an
    /// actual result.
    covered_result_count: usize,
    /// Total number of expected results for this query.
    total_result_count: usize,
    /// Indices (into `actual_results`) of the results that fully covered an
    /// expected result.
    covered_result_indices: Vec<usize>,
}
97
98fn main() -> Result<()> {
99 let cli = Cli::parse();
100 env_logger::init();
101
102 gpui::Application::headless().run(move |cx| {
103 let executor = cx.background_executor().clone();
104 let client = Arc::new(ReqwestClient::user_agent("Zed LLM evals").unwrap());
105 cx.set_http_client(client.clone());
106 match cli.command {
107 Commands::Fetch {} => {
108 executor
109 .clone()
110 .spawn(async move {
111 if let Err(err) = fetch_evaluation_resources(client, &executor).await {
112 eprintln!("Error: {}", err);
113 exit(1);
114 }
115 exit(0);
116 })
117 .detach();
118 }
119 Commands::Run { repo } => {
120 cx.spawn(async move |cx| {
121 if let Err(err) = run_evaluation(repo, &executor, cx).await {
122 eprintln!("Error: {}", err);
123 exit(1);
124 }
125 exit(0);
126 })
127 .detach();
128 }
129 }
130 });
131
132 Ok(())
133}
134
135async fn fetch_evaluation_resources(
136 http_client: Arc<dyn HttpClient>,
137 executor: &BackgroundExecutor,
138) -> Result<()> {
139 fetch_code_search_net_resources(&*http_client).await?;
140 fetch_eval_repos(executor, &*http_client).await?;
141 Ok(())
142}
143
144async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
145 eprintln!("Fetching CodeSearchNet evaluations...");
146
147 let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
148
149 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
150 fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
151
152 // Fetch the annotations CSV, which contains the human-annotated search relevances
153 let annotations_path = dataset_dir.join("annotations.csv");
154 let annotations_csv_content = if annotations_path.exists() {
155 fs::read_to_string(&annotations_path).expect("failed to read annotations")
156 } else {
157 let response = http_client
158 .get(annotations_url, Default::default(), true)
159 .await
160 .expect("failed to fetch annotations csv");
161 let mut body = String::new();
162 response
163 .into_body()
164 .read_to_string(&mut body)
165 .await
166 .expect("failed to read annotations.csv response");
167 fs::write(annotations_path, &body).expect("failed to write annotations.csv");
168 body
169 };
170
171 // Parse the annotations CSV. Skip over queries with zero relevance.
172 let rows = annotations_csv_content.lines().filter_map(|line| {
173 let mut values = line.split(',');
174 let _language = values.next()?;
175 let query = values.next()?;
176 let github_url = values.next()?;
177 let score = values.next()?;
178
179 if score == "0" {
180 return None;
181 }
182
183 let url_path = github_url.strip_prefix("https://github.com/")?;
184 let (url_path, hash) = url_path.split_once('#')?;
185 let (repo_name, url_path) = url_path.split_once("/blob/")?;
186 let (sha, file_path) = url_path.split_once('/')?;
187 let line_range = if let Some((start, end)) = hash.split_once('-') {
188 start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
189 } else {
190 let row = hash.strip_prefix("L")?.parse().ok()?;
191 row..=row
192 };
193 Some((repo_name, sha, query, file_path, line_range))
194 });
195
196 // Group the annotations by repo and sha.
197 let mut evaluations_by_repo = BTreeMap::new();
198 for (repo_name, sha, query, file_path, lines) in rows {
199 let evaluation_project = evaluations_by_repo
200 .entry((repo_name, sha))
201 .or_insert_with(|| EvaluationProject {
202 repo: repo_name.to_string(),
203 sha: sha.to_string(),
204 queries: Vec::new(),
205 });
206
207 let ix = evaluation_project
208 .queries
209 .iter()
210 .position(|entry| entry.query == query)
211 .unwrap_or_else(|| {
212 evaluation_project.queries.push(EvaluationQuery {
213 query: query.to_string(),
214 expected_results: Vec::new(),
215 });
216 evaluation_project.queries.len() - 1
217 });
218 let results = &mut evaluation_project.queries[ix].expected_results;
219 let result = EvaluationSearchResult {
220 file: file_path.to_string(),
221 lines,
222 };
223 if !results.contains(&result) {
224 results.push(result);
225 }
226 }
227
228 let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
229 let evaluations_path = dataset_dir.join("evaluations.json");
230 fs::write(
231 &evaluations_path,
232 serde_json::to_vec_pretty(&evaluations).unwrap(),
233 )
234 .unwrap();
235
236 eprintln!(
237 "Fetched CodeSearchNet evaluations into {}",
238 evaluations_path.display()
239 );
240
241 Ok(())
242}
243
/// Running totals accumulated across all evaluated queries, used for the
/// progress line and the final summary.
#[derive(Default, Debug)]
struct Counts {
    // Expected results whose line range was fully contained in a search result.
    covered_results: usize,
    // Expected results whose line range overlapped a search result at either end.
    overlapped_results: usize,
    // Expected results whose file appeared among the search results.
    covered_files: usize,
    // Total number of expected results considered.
    total_results: usize,
}
251
252async fn run_evaluation(
253 only_repo: Option<String>,
254 executor: &BackgroundExecutor,
255 cx: &mut AsyncApp,
256) -> Result<()> {
257 let mut http_client = None;
258 cx.update(|cx| {
259 let mut store = SettingsStore::new(cx);
260 store
261 .set_default_settings(settings::default_settings().as_ref(), cx)
262 .unwrap();
263 cx.set_global(store);
264 client::init_settings(cx);
265 language::init(cx);
266 Project::init_settings(cx);
267 http_client = Some(cx.http_client());
268 cx.update_flags(false, vec![]);
269 })
270 .unwrap();
271 let http_client = http_client.unwrap();
272 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
273 let evaluations_path = dataset_dir.join("evaluations.json");
274 let repos_dir = Path::new(EVAL_REPOS_DIR);
275 let db_path = Path::new(EVAL_DB_PATH);
276 let api_key = std::env::var("OPENAI_API_KEY").unwrap();
277 let fs = Arc::new(RealFs::new(None, cx.background_executor().clone())) as Arc<dyn Fs>;
278 let clock = Arc::new(RealSystemClock);
279 let client = cx
280 .update(|cx| {
281 Client::new(
282 clock,
283 Arc::new(http_client::HttpClientWithUrl::new(
284 http_client.clone(),
285 "https://zed.dev",
286 None,
287 )),
288 cx,
289 )
290 })
291 .unwrap();
292 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)).unwrap();
293 let node_runtime = NodeRuntime::unavailable();
294
295 let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
296 let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
297
298 let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
299 http_client.clone(),
300 OpenAiEmbeddingModel::TextEmbedding3Small,
301 open_ai::OPEN_AI_API_URL.to_string(),
302 api_key,
303 ));
304
305 let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
306 let debug_adapters = Arc::new(DapRegistry::default());
307 cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
308 .unwrap();
309
310 let mut counts = Counts::default();
311 eprint!("Running evals.");
312
313 let mut failures = Vec::new();
314
315 for evaluation_project in evaluations {
316 if only_repo
317 .as_ref()
318 .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
319 {
320 continue;
321 }
322
323 eprint!("\r\x1B[2K");
324 eprint!(
325 "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
326 counts.covered_results,
327 counts.total_results,
328 counts.overlapped_results,
329 counts.total_results,
330 counts.covered_files,
331 counts.total_results,
332 evaluation_project.repo
333 );
334
335 let repo_dir = repos_dir.join(&evaluation_project.repo);
336 if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
337 eprintln!("Skipping {}: directory not found", evaluation_project.repo);
338 continue;
339 }
340
341 let repo_db_path =
342 db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
343
344 let project = cx
345 .update(|cx| {
346 Project::local(
347 client.clone(),
348 node_runtime.clone(),
349 user_store.clone(),
350 language_registry.clone(),
351 debug_adapters.clone(),
352 fs.clone(),
353 None,
354 cx,
355 )
356 })
357 .unwrap();
358
359 let repo = evaluation_project.repo.clone();
360 if let Err(err) = run_eval_project(
361 evaluation_project,
362 &user_store,
363 repo_db_path,
364 &repo_dir,
365 &mut counts,
366 project,
367 embedding_provider.clone(),
368 fs.clone(),
369 cx,
370 )
371 .await
372 {
373 eprintln!("{repo} eval failed with error: {:?}", err);
374
375 failures.push((repo, err));
376 }
377 }
378
379 eprintln!(
380 "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
381 counts.covered_results,
382 counts.total_results,
383 counts.overlapped_results,
384 counts.total_results,
385 counts.covered_files,
386 counts.total_results,
387 failures.len(),
388 );
389
390 if failures.is_empty() {
391 Ok(())
392 } else {
393 eprintln!("Failures:\n");
394
395 for (index, (repo, failure)) in failures.iter().enumerate() {
396 eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
397 }
398
399 Err(anyhow::anyhow!("Some evals failed."))
400 }
401}
402
403async fn run_eval_project(
404 evaluation_project: EvaluationProject,
405 user_store: &Entity<UserStore>,
406 repo_db_path: PathBuf,
407 repo_dir: &Path,
408 counts: &mut Counts,
409 project: Entity<Project>,
410 embedding_provider: Arc<dyn EmbeddingProvider>,
411 fs: Arc<dyn Fs>,
412 cx: &mut AsyncApp,
413) -> Result<(), anyhow::Error> {
414 let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;
415
416 let (worktree, _) = project
417 .update(cx, |project, cx| {
418 project.find_or_create_worktree(repo_dir, true, cx)
419 })?
420 .await?;
421
422 worktree
423 .update(cx, |worktree, _| {
424 worktree.as_local().unwrap().scan_complete()
425 })?
426 .await;
427
428 let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
429 wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
430
431 for query in evaluation_project.queries {
432 let results = {
433 // Retry search up to 3 times in case of timeout, network failure, etc.
434 let mut retries_remaining = 3;
435 let mut result;
436
437 loop {
438 match cx.update(|cx| {
439 let project_index = project_index.read(cx);
440 project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
441 }) {
442 Ok(task) => match task.await {
443 Ok(answer) => {
444 result = Ok(answer);
445 break;
446 }
447 Err(err) => {
448 result = Err(err);
449 }
450 },
451 Err(err) => {
452 result = Err(err);
453 }
454 }
455
456 if retries_remaining > 0 {
457 eprintln!(
458 "Retrying search after it failed on query {:?} with {:?}",
459 query, result
460 );
461 retries_remaining -= 1;
462 } else {
463 eprintln!(
464 "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
465 query, result
466 );
467 break;
468 }
469 }
470
471 SemanticDb::load_results(result?, &fs.clone(), &cx).await?
472 };
473
474 let mut project_covered_result_count = 0;
475 let mut project_overlapped_result_count = 0;
476 let mut project_covered_file_count = 0;
477 let mut covered_result_indices = Vec::new();
478 for expected_result in &query.expected_results {
479 let mut file_matched = false;
480 let mut range_overlapped = false;
481 let mut range_covered = false;
482
483 for (ix, result) in results.iter().enumerate() {
484 if result.path.as_ref() == Path::new(&expected_result.file) {
485 file_matched = true;
486 let start_matched = result.row_range.contains(expected_result.lines.start());
487 let end_matched = result.row_range.contains(expected_result.lines.end());
488
489 if start_matched || end_matched {
490 range_overlapped = true;
491 }
492
493 if start_matched && end_matched {
494 range_covered = true;
495 covered_result_indices.push(ix);
496 break;
497 }
498 }
499 }
500
501 if range_covered {
502 project_covered_result_count += 1
503 };
504 if range_overlapped {
505 project_overlapped_result_count += 1
506 };
507 if file_matched {
508 project_covered_file_count += 1
509 };
510 }
511 let outcome_repo = evaluation_project.repo.clone();
512
513 let query_results = EvaluationQueryOutcome {
514 repo: outcome_repo,
515 query: query.query,
516 total_result_count: query.expected_results.len(),
517 covered_result_count: project_covered_result_count,
518 overlapped_result_count: project_overlapped_result_count,
519 covered_file_count: project_covered_file_count,
520 expected_results: query.expected_results,
521 actual_results: results
522 .iter()
523 .map(|result| EvaluationSearchResult {
524 file: result.path.to_string_lossy().to_string(),
525 lines: result.row_range.clone(),
526 })
527 .collect(),
528 covered_result_indices,
529 };
530
531 counts.overlapped_results += query_results.overlapped_result_count;
532 counts.covered_results += query_results.covered_result_count;
533 counts.covered_files += query_results.covered_file_count;
534 counts.total_results += query_results.total_result_count;
535
536 println!("{}", serde_json::to_string(&query_results)?);
537 }
538
539 user_store.update(cx, |_, _| {
540 drop(semantic_index);
541 drop(project);
542 drop(worktree);
543 drop(project_index);
544 })
545}
546
547async fn wait_for_indexing_complete(
548 project_index: &Entity<ProjectIndex>,
549 cx: &mut AsyncApp,
550 timeout: Option<Duration>,
551) {
552 let (tx, rx) = bounded(1);
553 let subscription = cx.update(|cx| {
554 cx.subscribe(project_index, move |_, event, _| {
555 if let Status::Idle = event {
556 let _ = tx.try_send(*event);
557 }
558 })
559 });
560
561 let result = match timeout {
562 Some(timeout_duration) => {
563 smol::future::or(
564 async {
565 rx.recv().await.map_err(|_| ())?;
566 Ok(())
567 },
568 async {
569 Timer::after(timeout_duration).await;
570 Err(())
571 },
572 )
573 .await
574 }
575 None => rx.recv().await.map(|_| ()).map_err(|_| ()),
576 };
577
578 match result {
579 Ok(_) => (),
580 Err(_) => {
581 if let Some(timeout) = timeout {
582 eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
583 }
584 }
585 }
586
587 drop(subscription);
588}
589
590async fn fetch_eval_repos(
591 executor: &BackgroundExecutor,
592 http_client: &dyn HttpClient,
593) -> Result<()> {
594 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
595 let evaluations_path = dataset_dir.join("evaluations.json");
596 let repos_dir = Path::new(EVAL_REPOS_DIR);
597
598 let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
599 let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
600
601 eprintln!("Fetching evaluation repositories...");
602
603 executor
604 .scoped(move |scope| {
605 let done_count = Arc::new(AtomicUsize::new(0));
606 let len = evaluations.len();
607 for chunk in evaluations.chunks(evaluations.len() / 8) {
608 let chunk = chunk.to_vec();
609 let done_count = done_count.clone();
610 scope.spawn(async move {
611 for EvaluationProject { repo, sha, .. } in chunk {
612 eprint!(
613 "\rFetching evaluation repositories ({}/{})...",
614 done_count.load(SeqCst),
615 len,
616 );
617
618 fetch_eval_repo(repo, sha, repos_dir, http_client).await;
619 done_count.fetch_add(1, SeqCst);
620 }
621 });
622 }
623 })
624 .await;
625
626 Ok(())
627}
628
629async fn fetch_eval_repo(
630 repo: String,
631 sha: String,
632 repos_dir: &Path,
633 http_client: &dyn HttpClient,
634) {
635 let Some((owner, repo_name)) = repo.split_once('/') else {
636 return;
637 };
638 let repo_dir = repos_dir.join(owner).join(repo_name);
639 fs::create_dir_all(&repo_dir).unwrap();
640 let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
641 if skip_eval_path.exists() {
642 return;
643 }
644 if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
645 if head_content.trim() == sha {
646 return;
647 }
648 }
649 let repo_response = http_client
650 .send(
651 http_client::Request::builder()
652 .method(Method::HEAD)
653 .uri(format!("https://github.com/{}", repo))
654 .body(Default::default())
655 .expect(""),
656 )
657 .await
658 .expect("failed to check github repo");
659 if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
660 fs::write(&skip_eval_path, "").unwrap();
661 eprintln!(
662 "Repo {repo} is no longer public ({:?}). Skipping",
663 repo_response.status()
664 );
665 return;
666 }
667 if !repo_dir.join(".git").exists() {
668 let init_output = util::command::new_std_command("git")
669 .current_dir(&repo_dir)
670 .args(&["init"])
671 .output()
672 .unwrap();
673 if !init_output.status.success() {
674 eprintln!(
675 "Failed to initialize git repository for {}: {}",
676 repo,
677 String::from_utf8_lossy(&init_output.stderr)
678 );
679 return;
680 }
681 }
682 let url = format!("https://github.com/{}.git", repo);
683 util::command::new_std_command("git")
684 .current_dir(&repo_dir)
685 .args(&["remote", "add", "-f", "origin", &url])
686 .stdin(Stdio::null())
687 .output()
688 .unwrap();
689 let fetch_output = util::command::new_std_command("git")
690 .current_dir(&repo_dir)
691 .args(&["fetch", "--depth", "1", "origin", &sha])
692 .stdin(Stdio::null())
693 .output()
694 .unwrap();
695 if !fetch_output.status.success() {
696 eprintln!(
697 "Failed to fetch {} for {}: {}",
698 sha,
699 repo,
700 String::from_utf8_lossy(&fetch_output.stderr)
701 );
702 return;
703 }
704 let checkout_output = util::command::new_std_command("git")
705 .current_dir(&repo_dir)
706 .args(&["checkout", &sha])
707 .output()
708 .unwrap();
709
710 if !checkout_output.status.success() {
711 eprintln!(
712 "Failed to checkout {} for {}: {}",
713 sha,
714 repo,
715 String::from_utf8_lossy(&checkout_output.stderr)
716 );
717 }
718}