use ::fs::{Fs, RealFs};
use anyhow::Result;
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Entity};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::NodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use reqwest_client::ReqwestClient;
use semantic_index::{
    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};

const CODESEARCH_NET_DIR: &str = "target/datasets/code-search-net";
const EVAL_REPOS_DIR: &str = "target/datasets/eval-repos";
const EVAL_DB_PATH: &str = "target/eval_db";
const SEARCH_RESULT_LIMIT: usize = 8;
const SKIP_EVAL_PATH: &str = ".skip_eval";
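// Command-line interface: `fetch` downloads the CodeSearchNet annotations and the
// evaluation repositories; `run` executes the semantic search evaluation, optionally
// restricted to a single repository via `--repo`.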
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(clap::Subcommand)]
enum Commands {
    Fetch {},
    Run {
        #[arg(long)]
        repo: Option<String>,
    },
}

#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQuery>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    file: String,
    lines: RangeInclusive<u32>,
}

#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    repo: String,
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
    actual_results: Vec<EvaluationSearchResult>,
    covered_file_count: usize,
    overlapped_result_count: usize,
    covered_result_count: usize,
    total_result_count: usize,
    covered_result_indices: Vec<usize>,
}
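// Entry point: runs a headless GPUI application and dispatches the selected subcommand
// onto the background executor, exiting the process with a non-zero status on failure.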
fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::Application::headless().run(move |cx| {
        let executor = cx.background_executor().clone();
        let client = Arc::new(ReqwestClient::user_agent("Zed LLM evals").unwrap());
        cx.set_http_client(client.clone());
        match cli.command {
            Commands::Fetch {} => {
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                cx.spawn(async move |cx| {
                    if let Err(err) = run_evaluation(repo, &executor, cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}
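// Fetches everything the `run` subcommand needs: the CodeSearchNet annotations and the
// repositories they reference.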
async fn fetch_evaluation_resources(
    http_client: Arc<dyn HttpClient>,
    executor: &BackgroundExecutor,
) -> Result<()> {
    fetch_code_search_net_resources(&*http_client).await?;
    fetch_eval_repos(executor, &*http_client).await?;
    Ok(())
}
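// Downloads the CodeSearchNet `annotationStore.csv` (columns: language, query, GitHub URL,
// relevance score), drops zero-relevance rows, groups the remaining annotations by
// (repo, sha), and writes the result to `evaluations.json` in CODESEARCH_NET_DIR.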
async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
    eprintln!("Fetching CodeSearchNet evaluations...");

    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";

    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");

    // Fetch the annotations CSV, which contains the human-annotated search relevances
    let annotations_path = dataset_dir.join("annotations.csv");
    let annotations_csv_content = if annotations_path.exists() {
        fs::read_to_string(&annotations_path).expect("failed to read annotations")
    } else {
        let response = http_client
            .get(annotations_url, Default::default(), true)
            .await
            .expect("failed to fetch annotations csv");
        let mut body = String::new();
        response
            .into_body()
            .read_to_string(&mut body)
            .await
            .expect("failed to read annotations.csv response");
        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
        body
    };

    // Parse the annotations CSV. Skip over queries with zero relevance.
    let rows = annotations_csv_content.lines().filter_map(|line| {
        let mut values = line.split(',');
        let _language = values.next()?;
        let query = values.next()?;
        let github_url = values.next()?;
        let score = values.next()?;

        if score == "0" {
            return None;
        }

        let url_path = github_url.strip_prefix("https://github.com/")?;
        let (url_path, hash) = url_path.split_once('#')?;
        let (repo_name, url_path) = url_path.split_once("/blob/")?;
        let (sha, file_path) = url_path.split_once('/')?;
        let line_range = if let Some((start, end)) = hash.split_once('-') {
            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
        } else {
            let row = hash.strip_prefix("L")?.parse().ok()?;
            row..=row
        };
        Some((repo_name, sha, query, file_path, line_range))
    });

    // Group the annotations by repo and sha.
    let mut evaluations_by_repo = BTreeMap::new();
    for (repo_name, sha, query, file_path, lines) in rows {
        let evaluation_project = evaluations_by_repo
            .entry((repo_name, sha))
            .or_insert_with(|| EvaluationProject {
                repo: repo_name.to_string(),
                sha: sha.to_string(),
                queries: Vec::new(),
            });

        let ix = evaluation_project
            .queries
            .iter()
            .position(|entry| entry.query == query)
            .unwrap_or_else(|| {
                evaluation_project.queries.push(EvaluationQuery {
                    query: query.to_string(),
                    expected_results: Vec::new(),
                });
                evaluation_project.queries.len() - 1
            });
        let results = &mut evaluation_project.queries[ix].expected_results;
        let result = EvaluationSearchResult {
            file: file_path.to_string(),
            lines,
        };
        if !results.contains(&result) {
            results.push(result);
        }
    }

    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
    let evaluations_path = dataset_dir.join("evaluations.json");
    fs::write(
        &evaluations_path,
        serde_json::to_vec_pretty(&evaluations).unwrap(),
    )
    .unwrap();

    eprintln!(
        "Fetched CodeSearchNet evaluations into {}",
        evaluations_path.display()
    );

    Ok(())
}
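// Running totals across all evaluated queries, used for the progress line and the final summary.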
#[derive(Default, Debug)]
struct Counts {
    covered_results: usize,
    overlapped_results: usize,
    covered_files: usize,
    total_results: usize,
}
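// Runs the full evaluation: initializes settings, languages, and an OpenAI embedding
// provider (requires OPENAI_API_KEY), then indexes and queries each evaluation project,
// accumulating coverage counts and reporting any per-repo failures at the end.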
async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncApp,
) -> Result<()> {
    let mut http_client = None;
    cx.update(|cx| {
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        http_client = Some(cx.http_client());
        cx.update_flags(false, vec![]);
    })
    .unwrap();
    let http_client = http_client.unwrap();
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
    let fs = Arc::new(RealFs::new(None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    http_client.clone(),
                    "https://zed.dev",
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)).unwrap();
    let node_runtime = NodeRuntime::unavailable();

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

    let mut counts = Counts::default();
    eprint!("Running evals.");

    let mut failures = Vec::new();

    for evaluation_project in evaluations {
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            counts.covered_results,
            counts.total_results,
            counts.overlapped_results,
            counts.total_results,
            counts.covered_files,
            counts.total_results,
            evaluation_project.repo
        );
        let repo_dir = repos_dir.join(&evaluation_project.repo);
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!(
                "Skipping {}: repository directory is missing or marked to skip",
                evaluation_project.repo
            );
            continue;
        }

        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));

        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let repo = evaluation_project.repo.clone();
        if let Err(err) = run_eval_project(
            evaluation_project,
            &user_store,
            repo_db_path,
            &repo_dir,
            &mut counts,
            project,
            embedding_provider.clone(),
            fs.clone(),
            cx,
        )
        .await
        {
            eprintln!("{repo} eval failed with error: {:?}", err);

            failures.push((repo, err));
        }
    }

    eprintln!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
        counts.covered_results,
        counts.total_results,
        counts.overlapped_results,
        counts.total_results,
        counts.covered_files,
        counts.total_results,
        failures.len(),
    );

    if failures.is_empty() {
        Ok(())
    } else {
        eprintln!("Failures:\n");

        for (index, (repo, failure)) in failures.iter().enumerate() {
            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
        }

        Err(anyhow::anyhow!("Some evals failed."))
    }
}
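// Evaluates a single repository: builds a semantic index over its worktree, runs every
// query against it, scores the results against the expected ranges, and prints one
// EvaluationQueryOutcome per query as a line of JSON on stdout.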
async fn run_eval_project(
    evaluation_project: EvaluationProject,
    user_store: &Entity<UserStore>,
    repo_db_path: PathBuf,
    repo_dir: &Path,
    counts: &mut Counts,
    project: Entity<Project>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    fs: Arc<dyn Fs>,
    cx: &mut AsyncApp,
) -> Result<(), anyhow::Error> {
    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;

    let (worktree, _) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(repo_dir, true, cx)
        })?
        .await?;

    worktree
        .update(cx, |worktree, _| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;

    for query in evaluation_project.queries {
        let results = {
            // Retry search up to 3 times in case of timeout, network failure, etc.
            let mut retries_remaining = 3;
            let mut result;

            loop {
                match cx.update(|cx| {
                    let project_index = project_index.read(cx);
                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
                }) {
                    Ok(task) => match task.await {
                        Ok(answer) => {
                            result = Ok(answer);
                            break;
                        }
                        Err(err) => {
                            result = Err(err);
                        }
                    },
                    Err(err) => {
                        result = Err(err);
                    }
                }

                if retries_remaining > 0 {
                    eprintln!(
                        "Retrying search after it failed on query {:?} with {:?}",
                        query, result
                    );
                    retries_remaining -= 1;
                } else {
                    eprintln!(
                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
                        query, result
                    );
                    break;
                }
            }

            SemanticDb::load_results(result?, &fs.clone(), &cx).await?
        };
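        // Score the search results against the expected ranges: a result "covers" an expected
        // range when it contains both endpoints, "overlaps" it when it contains at least one,
        // and a file match only requires the paths to be equal.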
        let mut project_covered_result_count = 0;
        let mut project_overlapped_result_count = 0;
        let mut project_covered_file_count = 0;
        let mut covered_result_indices = Vec::new();
        for expected_result in &query.expected_results {
            let mut file_matched = false;
            let mut range_overlapped = false;
            let mut range_covered = false;

            for (ix, result) in results.iter().enumerate() {
                if result.path.as_ref() == Path::new(&expected_result.file) {
                    file_matched = true;
                    let start_matched = result.row_range.contains(expected_result.lines.start());
                    let end_matched = result.row_range.contains(expected_result.lines.end());

                    if start_matched || end_matched {
                        range_overlapped = true;
                    }

                    if start_matched && end_matched {
                        range_covered = true;
                        covered_result_indices.push(ix);
                        break;
                    }
                }
            }

            if range_covered {
                project_covered_result_count += 1
            };
            if range_overlapped {
                project_overlapped_result_count += 1
            };
            if file_matched {
                project_covered_file_count += 1
            };
        }
        let outcome_repo = evaluation_project.repo.clone();

        let query_results = EvaluationQueryOutcome {
            repo: outcome_repo,
            query: query.query,
            total_result_count: query.expected_results.len(),
            covered_result_count: project_covered_result_count,
            overlapped_result_count: project_overlapped_result_count,
            covered_file_count: project_covered_file_count,
            expected_results: query.expected_results,
            actual_results: results
                .iter()
                .map(|result| EvaluationSearchResult {
                    file: result.path.to_string_lossy().to_string(),
                    lines: result.row_range.clone(),
                })
                .collect(),
            covered_result_indices,
        };

        counts.overlapped_results += query_results.overlapped_result_count;
        counts.covered_results += query_results.covered_result_count;
        counts.covered_files += query_results.covered_file_count;
        counts.total_results += query_results.total_result_count;

        println!("{}", serde_json::to_string(&query_results)?);
    }

    user_store.update(cx, |_, _| {
        drop(semantic_index);
        drop(project);
        drop(worktree);
        drop(project_index);
    })
}
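// Blocks until the project index reports `Status::Idle`, or until the optional timeout
// elapses, in which case a warning is printed and the function returns anyway.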
async fn wait_for_indexing_complete(
    project_index: &Entity<ProjectIndex>,
    cx: &mut AsyncApp,
    timeout: Option<Duration>,
) {
    let (tx, rx) = bounded(1);
    let subscription = cx.update(|cx| {
        cx.subscribe(project_index, move |_, event, _| {
            if let Status::Idle = event {
                let _ = tx.try_send(*event);
            }
        })
    });

    let result = match timeout {
        Some(timeout_duration) => {
            smol::future::or(
                async {
                    rx.recv().await.map_err(|_| ())?;
                    Ok(())
                },
                async {
                    Timer::after(timeout_duration).await;
                    Err(())
                },
            )
            .await
        }
        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
    };

    match result {
        Ok(_) => (),
        Err(_) => {
            if let Some(timeout) = timeout {
                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
            }
        }
    }

    drop(subscription);
}
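// Clones the repositories referenced by `evaluations.json`, fetching them concurrently
// in up to eight chunks on the background executor.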
async fn fetch_eval_repos(
    executor: &BackgroundExecutor,
    http_client: &dyn HttpClient,
) -> Result<()> {
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    eprintln!("Fetching evaluation repositories...");

    executor
        .scoped(move |scope| {
            let done_count = Arc::new(AtomicUsize::new(0));
            let len = evaluations.len();
            // Split into up to 8 chunks; clamp the chunk size to at least 1 so `chunks`
            // does not panic when there are fewer than 8 evaluation projects.
            for chunk in evaluations.chunks((evaluations.len() / 8).max(1)) {
                let chunk = chunk.to_vec();
                let done_count = done_count.clone();
                scope.spawn(async move {
                    for EvaluationProject { repo, sha, .. } in chunk {
                        eprint!(
                            "\rFetching evaluation repositories ({}/{})...",
                            done_count.load(SeqCst),
                            len,
                        );

                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
                        done_count.fetch_add(1, SeqCst);
                    }
                });
            }
        })
        .await;

    Ok(())
}
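// Shallow-fetches a single repository at the annotated SHA. Repositories that are no
// longer publicly accessible are marked with a `.skip_eval` file so later runs skip them.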
async fn fetch_eval_repo(
    repo: String,
    sha: String,
    repos_dir: &Path,
    http_client: &dyn HttpClient,
) {
    let Some((owner, repo_name)) = repo.split_once('/') else {
        return;
    };
    let repo_dir = repos_dir.join(owner).join(repo_name);
    fs::create_dir_all(&repo_dir).unwrap();
    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
    if skip_eval_path.exists() {
        return;
    }
    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
        if head_content.trim() == sha {
            return;
        }
    }
    let repo_response = http_client
        .send(
            http_client::Request::builder()
                .method(Method::HEAD)
                .uri(format!("https://github.com/{}", repo))
                .body(Default::default())
                .expect("failed to build HEAD request"),
        )
        .await
        .expect("failed to check github repo");
    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
        fs::write(&skip_eval_path, "").unwrap();
        eprintln!(
            "Repo {repo} is no longer public ({:?}). Skipping",
            repo_response.status()
        );
        return;
    }
    if !repo_dir.join(".git").exists() {
        let init_output = util::command::new_std_command("git")
            .current_dir(&repo_dir)
            .args(&["init"])
            .output()
            .unwrap();
        if !init_output.status.success() {
            eprintln!(
                "Failed to initialize git repository for {}: {}",
                repo,
                String::from_utf8_lossy(&init_output.stderr)
            );
            return;
        }
    }
    let url = format!("https://github.com/{}.git", repo);
    util::command::new_std_command("git")
        .current_dir(&repo_dir)
        .args(&["remote", "add", "-f", "origin", &url])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    let fetch_output = util::command::new_std_command("git")
        .current_dir(&repo_dir)
        .args(&["fetch", "--depth", "1", "origin", &sha])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    if !fetch_output.status.success() {
        eprintln!(
            "Failed to fetch {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&fetch_output.stderr)
        );
        return;
    }
    let checkout_output = util::command::new_std_command("git")
        .current_dir(&repo_dir)
        .args(&["checkout", &sha])
        .output()
        .unwrap();

    if !checkout_output.status.success() {
        eprintln!(
            "Failed to checkout {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&checkout_output.stderr)
        );
    }
}