use ::fs::{Fs, RealFs};
use anyhow::Result;
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::NodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use semantic_index::{
    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Command, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};
use ureq_client::UreqClient;

const CODESEARCH_NET_DIR: &str = "target/datasets/code-search-net";
const EVAL_REPOS_DIR: &str = "target/datasets/eval-repos";
const EVAL_DB_PATH: &str = "target/eval_db";
const SEARCH_RESULT_LIMIT: usize = 8;
const SKIP_EVAL_PATH: &str = ".skip_eval";

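/// Command-line interface for the semantic search evaluation.
///
/// `fetch` downloads the CodeSearchNet annotations and the repositories they
/// reference; `run` indexes those repositories and measures search quality,
/// optionally limited to a single repository with `--repo owner/name`. The
/// `run` subcommand requires the `OPENAI_API_KEY` environment variable.
/// Typical invocations, assuming this binary is an eval-style workspace
/// package (the actual package name may differ):
/// `cargo run -p <eval-package> -- fetch`, then
/// `cargo run -p <eval-package> -- run [--repo owner/name]`.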
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(clap::Subcommand)]
enum Commands {
    Fetch {},
    Run {
        #[arg(long)]
        repo: Option<String>,
    },
}

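/// One evaluation project corresponds to a single (repository, commit SHA)
/// pair from the CodeSearchNet annotations, together with the human-annotated
/// queries and their expected results for that commit. The `*Outcome` structs
/// below mirror these shapes and additionally record what the search actually
/// returned.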
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQuery>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    file: String,
    lines: RangeInclusive<u32>,
}

#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    repo: String,
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
    actual_results: Vec<EvaluationSearchResult>,
    covered_file_count: usize,
    overlapped_result_count: usize,
    covered_result_count: usize,
    total_result_count: usize,
    covered_result_indices: Vec<usize>,
}

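/// Entry point: starts a headless GPUI application, installs a Ureq-based
/// HTTP client, and dispatches to the selected subcommand, exiting the
/// process when it completes.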
fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::App::headless().run(move |cx| {
        let executor = cx.background_executor().clone();
        let client = Arc::new(UreqClient::new(
            None,
            "Zed LLM evals".to_string(),
            executor.clone(),
        ));
        cx.set_http_client(client.clone());
        match cli.command {
            Commands::Fetch {} => {
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                cx.spawn(|mut cx| async move {
                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}

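/// Downloads everything the evaluation needs: the CodeSearchNet annotation
/// dataset and shallow clones of every repository it references.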
async fn fetch_evaluation_resources(
    http_client: Arc<dyn HttpClient>,
    executor: &BackgroundExecutor,
) -> Result<()> {
    fetch_code_search_net_resources(&*http_client).await?;
    fetch_eval_repos(executor, &*http_client).await?;
    Ok(())
}

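/// Fetches the CodeSearchNet `annotationStore.csv` (or reuses a cached copy),
/// drops zero-relevance rows, groups the remaining annotations by
/// (repository, SHA), and writes the result to `evaluations.json` inside
/// `CODESEARCH_NET_DIR`.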
async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
    eprintln!("Fetching CodeSearchNet evaluations...");

    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";

    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    fs::create_dir_all(dataset_dir).expect("failed to create CodeSearchNet directory");

    // Fetch the annotations CSV, which contains the human-annotated search relevances
    let annotations_path = dataset_dir.join("annotations.csv");
    let annotations_csv_content = if annotations_path.exists() {
        fs::read_to_string(&annotations_path).expect("failed to read annotations")
    } else {
        let response = http_client
            .get(annotations_url, Default::default(), true)
            .await
            .expect("failed to fetch annotations csv");
        let mut body = String::new();
        response
            .into_body()
            .read_to_string(&mut body)
            .await
            .expect("failed to read annotations.csv response");
        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
        body
    };

    // Parse the annotations CSV. Skip over queries with zero relevance.
    let rows = annotations_csv_content.lines().filter_map(|line| {
        let mut values = line.split(',');
        let _language = values.next()?;
        let query = values.next()?;
        let github_url = values.next()?;
        let score = values.next()?;

        if score == "0" {
            return None;
        }

        let url_path = github_url.strip_prefix("https://github.com/")?;
        let (url_path, hash) = url_path.split_once('#')?;
        let (repo_name, url_path) = url_path.split_once("/blob/")?;
        let (sha, file_path) = url_path.split_once('/')?;
        let line_range = if let Some((start, end)) = hash.split_once('-') {
            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
        } else {
            let row = hash.strip_prefix("L")?.parse().ok()?;
            row..=row
        };
        Some((repo_name, sha, query, file_path, line_range))
    });

    // Group the annotations by repo and sha.
    let mut evaluations_by_repo = BTreeMap::new();
    for (repo_name, sha, query, file_path, lines) in rows {
        let evaluation_project = evaluations_by_repo
            .entry((repo_name, sha))
            .or_insert_with(|| EvaluationProject {
                repo: repo_name.to_string(),
                sha: sha.to_string(),
                queries: Vec::new(),
            });

        let ix = evaluation_project
            .queries
            .iter()
            .position(|entry| entry.query == query)
            .unwrap_or_else(|| {
                evaluation_project.queries.push(EvaluationQuery {
                    query: query.to_string(),
                    expected_results: Vec::new(),
                });
                evaluation_project.queries.len() - 1
            });
        let results = &mut evaluation_project.queries[ix].expected_results;
        let result = EvaluationSearchResult {
            file: file_path.to_string(),
            lines,
        };
        if !results.contains(&result) {
            results.push(result);
        }
    }

    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
    let evaluations_path = dataset_dir.join("evaluations.json");
    fs::write(
        &evaluations_path,
        serde_json::to_vec_pretty(&evaluations).unwrap(),
    )
    .unwrap();

    eprintln!(
        "Fetched CodeSearchNet evaluations into {}",
        evaluations_path.display()
    );

    Ok(())
}

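/// Aggregate counters accumulated across every query of every evaluated
/// project; used for the progress line and the final summary.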
#[derive(Default, Debug)]
struct Counts {
    covered_results: usize,
    overlapped_results: usize,
    covered_files: usize,
    total_results: usize,
}

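/// Runs the evaluation over every fetched repository (or a single one when
/// `only_repo` is set): initializes settings, languages, and the OpenAI
/// embedding provider, indexes each repository, runs its queries, and prints
/// one JSON outcome per query to stdout.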
async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncAppContext,
) -> Result<()> {
    let mut http_client = None;
    cx.update(|cx| {
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        http_client = Some(cx.http_client());
        cx.update_flags(false, vec![]);
    })
    .unwrap();
    let http_client = http_client.unwrap();
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    let api_key = std::env::var("OPENAI_API_KEY")
        .expect("OPENAI_API_KEY environment variable must be set");
    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    http_client.clone(),
                    "https://zed.dev",
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx
        .new_model(|cx| UserStore::new(client.clone(), cx))
        .unwrap();
    let node_runtime = NodeRuntime::unavailable();

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

    let mut counts = Counts::default();
    eprint!("Running evals.");

    let mut failures = Vec::new();

    for evaluation_project in evaluations {
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            counts.covered_results,
            counts.total_results,
            counts.overlapped_results,
            counts.total_results,
            counts.covered_files,
            counts.total_results,
            evaluation_project.repo
        );

        let repo_dir = repos_dir.join(&evaluation_project.repo);
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!(
                "Skipping {}: repository directory is missing or marked to skip",
                evaluation_project.repo
            );
            continue;
        }

        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));

        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let repo = evaluation_project.repo.clone();
        if let Err(err) = run_eval_project(
            evaluation_project,
            &user_store,
            repo_db_path,
            &repo_dir,
            &mut counts,
            project,
            embedding_provider.clone(),
            fs.clone(),
            cx,
        )
        .await
        {
            eprintln!("{repo} eval failed with error: {:?}", err);

            failures.push((repo, err));
        }
    }

    eprintln!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
        counts.covered_results,
        counts.total_results,
        counts.overlapped_results,
        counts.total_results,
        counts.covered_files,
        counts.total_results,
        failures.len(),
    );

    if failures.is_empty() {
        Ok(())
    } else {
        eprintln!("Failures:\n");

        for (index, (repo, failure)) in failures.iter().enumerate() {
            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
        }

        Err(anyhow::anyhow!("Some evals failed."))
    }
}

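/// Indexes a single repository with `SemanticDb`, runs each of its evaluation
/// queries against the index (retrying transient search failures), compares
/// the results to the expected line ranges, updates the aggregate counts, and
/// prints one `EvaluationQueryOutcome` JSON object per query.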
#[allow(clippy::too_many_arguments)]
async fn run_eval_project(
    evaluation_project: EvaluationProject,
    user_store: &Model<UserStore>,
    repo_db_path: PathBuf,
    repo_dir: &Path,
    counts: &mut Counts,
    project: Model<Project>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    fs: Arc<dyn Fs>,
    cx: &mut AsyncAppContext,
) -> Result<(), anyhow::Error> {
    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;

    let (worktree, _) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(repo_dir, true, cx)
        })?
        .await?;

    worktree
        .update(cx, |worktree, _| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;

    for query in evaluation_project.queries {
        let results = {
            // Retry search up to 3 times in case of timeout, network failure, etc.
            let mut retries_remaining = 3;
            let mut result;

            loop {
                match cx.update(|cx| {
                    let project_index = project_index.read(cx);
                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
                }) {
                    Ok(task) => match task.await {
                        Ok(answer) => {
                            result = Ok(answer);
                            break;
                        }
                        Err(err) => {
                            result = Err(err);
                        }
                    },
                    Err(err) => {
                        result = Err(err);
                    }
                }

                if retries_remaining > 0 {
                    eprintln!(
                        "Retrying search after it failed on query {:?} with {:?}",
                        query, result
                    );
                    retries_remaining -= 1;
                } else {
                    eprintln!(
                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
                        query, result
                    );
                    break;
                }
            }

            SemanticDb::load_results(result?, &fs, &cx).await?
        };

        let mut project_covered_result_count = 0;
        let mut project_overlapped_result_count = 0;
        let mut project_covered_file_count = 0;
        let mut covered_result_indices = Vec::new();
        for expected_result in &query.expected_results {
            let mut file_matched = false;
            let mut range_overlapped = false;
            let mut range_covered = false;

            for (ix, result) in results.iter().enumerate() {
                if result.path.as_ref() == Path::new(&expected_result.file) {
                    file_matched = true;
                    let start_matched = result.row_range.contains(expected_result.lines.start());
                    let end_matched = result.row_range.contains(expected_result.lines.end());

                    if start_matched || end_matched {
                        range_overlapped = true;
                    }

                    if start_matched && end_matched {
                        range_covered = true;
                        covered_result_indices.push(ix);
                        break;
                    }
                }
            }

            if range_covered {
                project_covered_result_count += 1;
            }
            if range_overlapped {
                project_overlapped_result_count += 1;
            }
            if file_matched {
                project_covered_file_count += 1;
            }
        }
        let outcome_repo = evaluation_project.repo.clone();

        let query_results = EvaluationQueryOutcome {
            repo: outcome_repo,
            query: query.query,
            total_result_count: query.expected_results.len(),
            covered_result_count: project_covered_result_count,
            overlapped_result_count: project_overlapped_result_count,
            covered_file_count: project_covered_file_count,
            expected_results: query.expected_results,
            actual_results: results
                .iter()
                .map(|result| EvaluationSearchResult {
                    file: result.path.to_string_lossy().to_string(),
                    lines: result.row_range.clone(),
                })
                .collect(),
            covered_result_indices,
        };

        counts.overlapped_results += query_results.overlapped_result_count;
        counts.covered_results += query_results.covered_result_count;
        counts.covered_files += query_results.covered_file_count;
        counts.total_results += query_results.total_result_count;

        println!("{}", serde_json::to_string(&query_results)?);
    }

    user_store.update(cx, |_, _| {
        drop(semantic_index);
        drop(project);
        drop(worktree);
        drop(project_index);
    })
}

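/// Waits for the project index to report `Status::Idle`, or for the optional
/// timeout to elapse, in which case a timeout message is printed and the
/// function returns without the index having finished.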
async fn wait_for_indexing_complete(
    project_index: &Model<ProjectIndex>,
    cx: &mut AsyncAppContext,
    timeout: Option<Duration>,
) {
    let (tx, rx) = bounded(1);
    let subscription = cx.update(|cx| {
        cx.subscribe(project_index, move |_, event, _| {
            if let Status::Idle = event {
                let _ = tx.try_send(*event);
            }
        })
    });

    let result = match timeout {
        Some(timeout_duration) => {
            smol::future::or(
                async {
                    rx.recv().await.map_err(|_| ())?;
                    Ok(())
                },
                async {
                    Timer::after(timeout_duration).await;
                    Err(())
                },
            )
            .await
        }
        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
    };

    match result {
        Ok(_) => (),
        Err(_) => {
            if let Some(timeout) = timeout {
                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
            }
        }
    }

    drop(subscription);
}

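/// Clones every repository referenced by `evaluations.json` into
/// `EVAL_REPOS_DIR`, splitting the work across up to eight concurrent tasks
/// and printing progress as repositories are fetched.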
async fn fetch_eval_repos(
    executor: &BackgroundExecutor,
    http_client: &dyn HttpClient,
) -> Result<()> {
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    eprintln!("Fetching evaluation repositories...");

    executor
        .scoped(move |scope| {
            let done_count = Arc::new(AtomicUsize::new(0));
            let len = evaluations.len();
            // Split the work across up to eight tasks. The chunk size must be
            // at least one, or `chunks` would panic for small datasets.
            for chunk in evaluations.chunks((evaluations.len() / 8).max(1)) {
                let chunk = chunk.to_vec();
                let done_count = done_count.clone();
                scope.spawn(async move {
                    for EvaluationProject { repo, sha, .. } in chunk {
                        eprint!(
                            "\rFetching evaluation repositories ({}/{})...",
                            done_count.load(SeqCst),
                            len,
                        );

                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
                        done_count.fetch_add(1, SeqCst);
                    }
                });
            }
        })
        .await;

    Ok(())
}

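/// Fetches a single repository at the given SHA into `repos_dir`: initializes
/// a git repository if needed, adds GitHub as a remote, and performs a shallow
/// fetch and checkout of the SHA. Repositories that are no longer publicly
/// reachable are marked with a `.skip_eval` file so subsequent runs skip them.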
async fn fetch_eval_repo(
    repo: String,
    sha: String,
    repos_dir: &Path,
    http_client: &dyn HttpClient,
) {
    let Some((owner, repo_name)) = repo.split_once('/') else {
        return;
    };
    let repo_dir = repos_dir.join(owner).join(repo_name);
    fs::create_dir_all(&repo_dir).unwrap();
    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
    if skip_eval_path.exists() {
        return;
    }
    if let Ok(head_content) = fs::read_to_string(repo_dir.join(".git").join("HEAD")) {
        if head_content.trim() == sha {
            return;
        }
    }
    let repo_response = http_client
        .send(
            http_client::Request::builder()
                .method(Method::HEAD)
                .uri(format!("https://github.com/{}", repo))
                .body(Default::default())
                .expect("failed to build HEAD request"),
        )
        .await
        .expect("failed to check github repo");
    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
        fs::write(&skip_eval_path, "").unwrap();
        eprintln!(
            "Repo {repo} is no longer public ({:?}). Skipping",
            repo_response.status()
        );
        return;
    }
    if !repo_dir.join(".git").exists() {
        let init_output = Command::new("git")
            .current_dir(&repo_dir)
            .args(&["init"])
            .output()
            .unwrap();
        if !init_output.status.success() {
            eprintln!(
                "Failed to initialize git repository for {}: {}",
                repo,
                String::from_utf8_lossy(&init_output.stderr)
            );
            return;
        }
    }
    let url = format!("https://github.com/{}.git", repo);
    Command::new("git")
        .current_dir(&repo_dir)
        .args(&["remote", "add", "-f", "origin", &url])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    let fetch_output = Command::new("git")
        .current_dir(&repo_dir)
        .args(&["fetch", "--depth", "1", "origin", &sha])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    if !fetch_output.status.success() {
        eprintln!(
            "Failed to fetch {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&fetch_output.stderr)
        );
        return;
    }
    let checkout_output = Command::new("git")
        .current_dir(&repo_dir)
        .args(&["checkout", &sha])
        .output()
        .unwrap();

    if !checkout_output.status.success() {
        eprintln!(
            "Failed to checkout {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&checkout_output.stderr)
        );
    }
}