use ::fs::{Fs, RealFs};
use anyhow::Result;
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Entity};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::NodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use reqwest_client::ReqwestClient;
use semantic_index::{
    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};

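// Dataset, repository, and database locations (relative to the workspace root), plus
// the number of search results requested per query and the marker file that flags a
// repository as skipped.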
const CODESEARCH_NET_DIR: &str = "target/datasets/code-search-net";
const EVAL_REPOS_DIR: &str = "target/datasets/eval-repos";
const EVAL_DB_PATH: &str = "target/eval_db";
const SEARCH_RESULT_LIMIT: usize = 8;
const SKIP_EVAL_PATH: &str = ".skip_eval";

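/// Command-line interface for the semantic search eval binary.
///
/// A hypothetical invocation (the binary name depends on how this crate is built and
/// run; the `fetch`/`run` subcommands, the `--repo` flag, and the `OPENAI_API_KEY`
/// environment variable are taken from the code below):
///
/// ```text
/// eval fetch
/// OPENAI_API_KEY=... eval run --repo <owner>/<name>
/// ```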
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(clap::Subcommand)]
enum Commands {
    Fetch {},
    Run {
        #[arg(long)]
        repo: Option<String>,
    },
}

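/// A repository pinned to a specific commit, together with the CodeSearchNet queries
/// annotated against it.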
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQuery>,
}

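/// A natural-language query and the file/line ranges annotated as relevant to it.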
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
}

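/// A file path and inclusive line range, used for both expected and actual results.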
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    file: String,
    lines: RangeInclusive<u32>,
}

#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}

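/// The outcome for a single query, emitted as one JSON line on stdout.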
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    repo: String,
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
    actual_results: Vec<EvaluationSearchResult>,
    covered_file_count: usize,
    overlapped_result_count: usize,
    covered_result_count: usize,
    total_result_count: usize,
    covered_result_indices: Vec<usize>,
}

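// The eval runs as a headless GPUI application, presumably so the project, worktree,
// and semantic-index entities can be driven the same way they are inside the editor.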
fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::Application::headless().run(move |cx| {
        let executor = cx.background_executor().clone();
        let client = Arc::new(ReqwestClient::user_agent("Zed LLM evals").unwrap());
        cx.set_http_client(client.clone());
        match cli.command {
            Commands::Fetch {} => {
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                cx.spawn(|mut cx| async move {
                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}

async fn fetch_evaluation_resources(
    http_client: Arc<dyn HttpClient>,
    executor: &BackgroundExecutor,
) -> Result<()> {
    fetch_code_search_net_resources(&*http_client).await?;
    fetch_eval_repos(executor, &*http_client).await?;
    Ok(())
}

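/// Downloads the CodeSearchNet `annotationStore.csv` and converts it into
/// `evaluations.json`, keeping only rows with a nonzero relevance score and grouping
/// them by repository and commit.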
async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
    eprintln!("Fetching CodeSearchNet evaluations...");

    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";

    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");

    // Fetch the annotations CSV, which contains the human-annotated search relevances
    let annotations_path = dataset_dir.join("annotations.csv");
    let annotations_csv_content = if annotations_path.exists() {
        fs::read_to_string(&annotations_path).expect("failed to read annotations")
    } else {
        let response = http_client
            .get(annotations_url, Default::default(), true)
            .await
            .expect("failed to fetch annotations csv");
        let mut body = String::new();
        response
            .into_body()
            .read_to_string(&mut body)
            .await
            .expect("failed to read annotations.csv response");
        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
        body
    };

    // Parse the annotations CSV. Skip over queries with zero relevance.
    let rows = annotations_csv_content.lines().filter_map(|line| {
        let mut values = line.split(',');
        let _language = values.next()?;
        let query = values.next()?;
        let github_url = values.next()?;
        let score = values.next()?;

        if score == "0" {
            return None;
        }

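        // Annotated GitHub links look like
        // https://github.com/<owner>/<repo>/blob/<sha>/<path>#L<start>(-L<end>);
        // pull the repo, sha, path, and line range back out of that URL.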
        let url_path = github_url.strip_prefix("https://github.com/")?;
        let (url_path, hash) = url_path.split_once('#')?;
        let (repo_name, url_path) = url_path.split_once("/blob/")?;
        let (sha, file_path) = url_path.split_once('/')?;
        let line_range = if let Some((start, end)) = hash.split_once('-') {
            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
        } else {
            let row = hash.strip_prefix("L")?.parse().ok()?;
            row..=row
        };
        Some((repo_name, sha, query, file_path, line_range))
    });

    // Group the annotations by repo and sha.
    let mut evaluations_by_repo = BTreeMap::new();
    for (repo_name, sha, query, file_path, lines) in rows {
        let evaluation_project = evaluations_by_repo
            .entry((repo_name, sha))
            .or_insert_with(|| EvaluationProject {
                repo: repo_name.to_string(),
                sha: sha.to_string(),
                queries: Vec::new(),
            });

        let ix = evaluation_project
            .queries
            .iter()
            .position(|entry| entry.query == query)
            .unwrap_or_else(|| {
                evaluation_project.queries.push(EvaluationQuery {
                    query: query.to_string(),
                    expected_results: Vec::new(),
                });
                evaluation_project.queries.len() - 1
            });
        let results = &mut evaluation_project.queries[ix].expected_results;
        let result = EvaluationSearchResult {
            file: file_path.to_string(),
            lines,
        };
        if !results.contains(&result) {
            results.push(result);
        }
    }

    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
    let evaluations_path = dataset_dir.join("evaluations.json");
    fs::write(
        &evaluations_path,
        serde_json::to_vec_pretty(&evaluations).unwrap(),
    )
    .unwrap();

    eprintln!(
        "Fetched CodeSearchNet evaluations into {}",
        evaluations_path.display()
    );

    Ok(())
}

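/// Aggregate totals across all evaluated queries, used for the progress and summary
/// lines printed to stderr.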
#[derive(Default, Debug)]
struct Counts {
    covered_results: usize,
    overlapped_results: usize,
    covered_files: usize,
    total_results: usize,
}

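/// Indexes and searches every fetched repository (or only `--repo`, when given),
/// printing one JSON outcome per query to stdout and a progress summary to stderr.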
async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncApp,
) -> Result<()> {
    let mut http_client = None;
    cx.update(|cx| {
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        http_client = Some(cx.http_client());
        cx.update_flags(false, vec![]);
    })
    .unwrap();
    let http_client = http_client.unwrap();
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    http_client.clone(),
                    "https://zed.dev",
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)).unwrap();
    let node_runtime = NodeRuntime::unavailable();

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

    let mut counts = Counts::default();
    eprint!("Running evals.");

    let mut failures = Vec::new();

    for evaluation_project in evaluations {
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            counts.covered_results,
            counts.total_results,
            counts.overlapped_results,
            counts.total_results,
            counts.covered_files,
            counts.total_results,
            evaluation_project.repo
        );

        let repo_dir = repos_dir.join(&evaluation_project.repo);
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!(
                "Skipping {}: repository not fetched or marked to skip",
                evaluation_project.repo
            );
            continue;
        }

        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));

        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let repo = evaluation_project.repo.clone();
        if let Err(err) = run_eval_project(
            evaluation_project,
            &user_store,
            repo_db_path,
            &repo_dir,
            &mut counts,
            project,
            embedding_provider.clone(),
            fs.clone(),
            cx,
        )
        .await
        {
            eprintln!("{repo} eval failed with error: {:?}", err);

            failures.push((repo, err));
        }
    }

    eprintln!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
        counts.covered_results,
        counts.total_results,
        counts.overlapped_results,
        counts.total_results,
        counts.covered_files,
        counts.total_results,
        failures.len(),
    );

    if failures.is_empty() {
        Ok(())
    } else {
        eprintln!("Failures:\n");

        for (index, (repo, failure)) in failures.iter().enumerate() {
            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
        }

        Err(anyhow::anyhow!("Some evals failed."))
    }
}

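/// Indexes a single repository with the semantic index, runs every query against it,
/// and scores how well the retrieved excerpts cover the annotated results.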
async fn run_eval_project(
    evaluation_project: EvaluationProject,
    user_store: &Entity<UserStore>,
    repo_db_path: PathBuf,
    repo_dir: &Path,
    counts: &mut Counts,
    project: Entity<Project>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    fs: Arc<dyn Fs>,
    cx: &mut AsyncApp,
) -> Result<(), anyhow::Error> {
    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;

    let (worktree, _) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(repo_dir, true, cx)
        })?
        .await?;

    worktree
        .update(cx, |worktree, _| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;

    for query in evaluation_project.queries {
        let results = {
            // Retry search up to 3 times in case of timeout, network failure, etc.
            let mut retries_remaining = 3;
            let mut result;

            loop {
                match cx.update(|cx| {
                    let project_index = project_index.read(cx);
                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
                }) {
                    Ok(task) => match task.await {
                        Ok(answer) => {
                            result = Ok(answer);
                            break;
                        }
                        Err(err) => {
                            result = Err(err);
                        }
                    },
                    Err(err) => {
                        result = Err(err);
                    }
                }

                if retries_remaining > 0 {
                    eprintln!(
                        "Retrying search after it failed on query {:?} with {:?}",
                        query, result
                    );
                    retries_remaining -= 1;
                } else {
                    eprintln!(
                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
                        query, result
                    );
                    break;
                }
            }

            SemanticDb::load_results(result?, &fs.clone(), &cx).await?
        };

        let mut project_covered_result_count = 0;
        let mut project_overlapped_result_count = 0;
        let mut project_covered_file_count = 0;
        let mut covered_result_indices = Vec::new();
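        // Score coverage: an expected result is covered when a single retrieved
        // excerpt contains both its start and end lines, overlapped when an excerpt
        // contains either endpoint, and file-matched when any excerpt comes from the
        // same file.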
        for expected_result in &query.expected_results {
            let mut file_matched = false;
            let mut range_overlapped = false;
            let mut range_covered = false;

            for (ix, result) in results.iter().enumerate() {
                if result.path.as_ref() == Path::new(&expected_result.file) {
                    file_matched = true;
                    let start_matched = result.row_range.contains(expected_result.lines.start());
                    let end_matched = result.row_range.contains(expected_result.lines.end());

                    if start_matched || end_matched {
                        range_overlapped = true;
                    }

                    if start_matched && end_matched {
                        range_covered = true;
                        covered_result_indices.push(ix);
                        break;
                    }
                }
            }

            if range_covered {
                project_covered_result_count += 1
            };
            if range_overlapped {
                project_overlapped_result_count += 1
            };
            if file_matched {
                project_covered_file_count += 1
            };
        }
        let outcome_repo = evaluation_project.repo.clone();

        let query_results = EvaluationQueryOutcome {
            repo: outcome_repo,
            query: query.query,
            total_result_count: query.expected_results.len(),
            covered_result_count: project_covered_result_count,
            overlapped_result_count: project_overlapped_result_count,
            covered_file_count: project_covered_file_count,
            expected_results: query.expected_results,
            actual_results: results
                .iter()
                .map(|result| EvaluationSearchResult {
                    file: result.path.to_string_lossy().to_string(),
                    lines: result.row_range.clone(),
                })
                .collect(),
            covered_result_indices,
        };

        counts.overlapped_results += query_results.overlapped_result_count;
        counts.covered_results += query_results.covered_result_count;
        counts.covered_files += query_results.covered_file_count;
        counts.total_results += query_results.total_result_count;

        println!("{}", serde_json::to_string(&query_results)?);
    }

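    // Drop the per-repository entities inside an entity update, presumably so that
    // teardown happens on the app thread before the next repository is evaluated.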
    user_store.update(cx, |_, _| {
        drop(semantic_index);
        drop(project);
        drop(worktree);
        drop(project_index);
    })
}

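/// Waits for the project index to report `Status::Idle`, giving up after `timeout`
/// if one is provided.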
async fn wait_for_indexing_complete(
    project_index: &Entity<ProjectIndex>,
    cx: &mut AsyncApp,
    timeout: Option<Duration>,
) {
    let (tx, rx) = bounded(1);
    let subscription = cx.update(|cx| {
        cx.subscribe(project_index, move |_, event, _| {
            if let Status::Idle = event {
                let _ = tx.try_send(*event);
            }
        })
    });

    let result = match timeout {
        Some(timeout_duration) => {
            smol::future::or(
                async {
                    rx.recv().await.map_err(|_| ())?;
                    Ok(())
                },
                async {
                    Timer::after(timeout_duration).await;
                    Err(())
                },
            )
            .await
        }
        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
    };

    match result {
        Ok(_) => (),
        Err(_) => {
            if let Some(timeout) = timeout {
                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
            }
        }
    }

    drop(subscription);
}

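/// Clones every repository referenced by `evaluations.json` at its pinned commit,
/// splitting the work across roughly eight concurrent tasks on the background
/// executor.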
async fn fetch_eval_repos(
    executor: &BackgroundExecutor,
    http_client: &dyn HttpClient,
) -> Result<()> {
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    eprintln!("Fetching evaluation repositories...");

    executor
        .scoped(move |scope| {
            let done_count = Arc::new(AtomicUsize::new(0));
            let len = evaluations.len();
            for chunk in evaluations.chunks((evaluations.len() / 8).max(1)) {
                let chunk = chunk.to_vec();
                let done_count = done_count.clone();
                scope.spawn(async move {
                    for EvaluationProject { repo, sha, .. } in chunk {
                        eprint!(
                            "\rFetching evaluation repositories ({}/{})...",
                            done_count.load(SeqCst),
                            len,
                        );

                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
                        done_count.fetch_add(1, SeqCst);
                    }
                });
            }
        })
        .await;

    Ok(())
}

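/// Fetches a single repository at its pinned commit via `git init` plus a shallow
/// `git fetch`, writing a `.skip_eval` marker when the repository is no longer
/// publicly accessible on GitHub.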
async fn fetch_eval_repo(
    repo: String,
    sha: String,
    repos_dir: &Path,
    http_client: &dyn HttpClient,
) {
    let Some((owner, repo_name)) = repo.split_once('/') else {
        return;
    };
    let repo_dir = repos_dir.join(owner).join(repo_name);
    fs::create_dir_all(&repo_dir).unwrap();
    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
    if skip_eval_path.exists() {
        return;
    }
    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
        if head_content.trim() == sha {
            return;
        }
    }
    let repo_response = http_client
        .send(
            http_client::Request::builder()
                .method(Method::HEAD)
                .uri(format!("https://github.com/{}", repo))
                .body(Default::default())
                .expect("failed to build request"),
        )
        .await
        .expect("failed to check github repo");
    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
        fs::write(&skip_eval_path, "").unwrap();
        eprintln!(
            "Repo {repo} is no longer public ({:?}). Skipping",
            repo_response.status()
        );
        return;
    }
    if !repo_dir.join(".git").exists() {
        let init_output = util::command::new_std_command("git")
            .current_dir(&repo_dir)
            .args(&["init"])
            .output()
            .unwrap();
        if !init_output.status.success() {
            eprintln!(
                "Failed to initialize git repository for {}: {}",
                repo,
                String::from_utf8_lossy(&init_output.stderr)
            );
            return;
        }
    }
    let url = format!("https://github.com/{}.git", repo);
    util::command::new_std_command("git")
        .current_dir(&repo_dir)
        .args(&["remote", "add", "-f", "origin", &url])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    let fetch_output = util::command::new_std_command("git")
        .current_dir(&repo_dir)
        .args(&["fetch", "--depth", "1", "origin", &sha])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    if !fetch_output.status.success() {
        eprintln!(
            "Failed to fetch {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&fetch_output.stderr)
        );
        return;
    }
    let checkout_output = util::command::new_std_command("git")
        .current_dir(&repo_dir)
        .args(&["checkout", &sha])
        .output()
        .unwrap();

    if !checkout_output.status.success() {
        eprintln!(
            "Failed to checkout {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&checkout_output.stderr)
        );
    }
}