use ::fs::{Fs, RealFs};
use anyhow::Result;
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Entity};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::NodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use reqwest_client::ReqwestClient;
use semantic_index::{
    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};

const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
const EVAL_DB_PATH: &'static str = "target/eval_db";
const SEARCH_RESULT_LIMIT: usize = 8;
const SKIP_EVAL_PATH: &'static str = ".skip_eval";

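/// Command-line arguments for the evaluation binary.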
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(clap::Subcommand)]
enum Commands {
    Fetch {},
    Run {
        #[arg(long)]
        repo: Option<String>,
    },
}

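/// A repository pinned to a specific SHA, together with the annotated queries to evaluate against it.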
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQuery>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    file: String,
    lines: RangeInclusive<u32>,
}

#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}

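/// Per-query outcome. A result is "covered" when a returned excerpt fully contains the expected
/// line range, "overlapped" when it contains at least one endpoint, and a file is counted whenever
/// any returned excerpt comes from the expected file.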
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    repo: String,
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
    actual_results: Vec<EvaluationSearchResult>,
    covered_file_count: usize,
    overlapped_result_count: usize,
    covered_result_count: usize,
    total_result_count: usize,
    covered_result_indices: Vec<usize>,
}

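/// Entry point: parses the CLI and runs the chosen subcommand (`fetch` or `run`) inside a headless
/// GPUI application, exiting with a non-zero status on failure.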
fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::Application::headless().run(move |cx| {
        let executor = cx.background_executor().clone();
        let client = Arc::new(ReqwestClient::user_agent("Zed LLM evals").unwrap());
        cx.set_http_client(client.clone());
        match cli.command {
            Commands::Fetch {} => {
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                cx.spawn(|mut cx| async move {
                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}

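/// Downloads the CodeSearchNet annotations and then clones every repository referenced by them.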
async fn fetch_evaluation_resources(
    http_client: Arc<dyn HttpClient>,
    executor: &BackgroundExecutor,
) -> Result<()> {
    fetch_code_search_net_resources(&*http_client).await?;
    fetch_eval_repos(executor, &*http_client).await?;
    Ok(())
}

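/// Fetches the CodeSearchNet annotation CSV (caching it on disk), groups the non-zero-relevance
/// annotations by repository and SHA, and writes them out as `evaluations.json`.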
async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
    eprintln!("Fetching CodeSearchNet evaluations...");

    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";

    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");

    // Fetch the annotations CSV, which contains the human-annotated search relevances
    let annotations_path = dataset_dir.join("annotations.csv");
    let annotations_csv_content = if annotations_path.exists() {
        fs::read_to_string(&annotations_path).expect("failed to read annotations")
    } else {
        let response = http_client
            .get(annotations_url, Default::default(), true)
            .await
            .expect("failed to fetch annotations csv");
        let mut body = String::new();
        response
            .into_body()
            .read_to_string(&mut body)
            .await
            .expect("failed to read annotations.csv response");
        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
        body
    };

    // Parse the annotations CSV. Skip over queries with zero relevance.
    let rows = annotations_csv_content.lines().filter_map(|line| {
        let mut values = line.split(',');
        let _language = values.next()?;
        let query = values.next()?;
        let github_url = values.next()?;
        let score = values.next()?;

        if score == "0" {
            return None;
        }

        let url_path = github_url.strip_prefix("https://github.com/")?;
        let (url_path, hash) = url_path.split_once('#')?;
        let (repo_name, url_path) = url_path.split_once("/blob/")?;
        let (sha, file_path) = url_path.split_once('/')?;
        let line_range = if let Some((start, end)) = hash.split_once('-') {
            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
        } else {
            let row = hash.strip_prefix("L")?.parse().ok()?;
            row..=row
        };
        Some((repo_name, sha, query, file_path, line_range))
    });

    // Group the annotations by repo and sha.
    let mut evaluations_by_repo = BTreeMap::new();
    for (repo_name, sha, query, file_path, lines) in rows {
        let evaluation_project = evaluations_by_repo
            .entry((repo_name, sha))
            .or_insert_with(|| EvaluationProject {
                repo: repo_name.to_string(),
                sha: sha.to_string(),
                queries: Vec::new(),
            });

        let ix = evaluation_project
            .queries
            .iter()
            .position(|entry| entry.query == query)
            .unwrap_or_else(|| {
                evaluation_project.queries.push(EvaluationQuery {
                    query: query.to_string(),
                    expected_results: Vec::new(),
                });
                evaluation_project.queries.len() - 1
            });
        let results = &mut evaluation_project.queries[ix].expected_results;
        let result = EvaluationSearchResult {
            file: file_path.to_string(),
            lines,
        };
        if !results.contains(&result) {
            results.push(result);
        }
    }

    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
    let evaluations_path = dataset_dir.join("evaluations.json");
    fs::write(
        &evaluations_path,
        serde_json::to_vec_pretty(&evaluations).unwrap(),
    )
    .unwrap();

    eprintln!(
        "Fetched CodeSearchNet evaluations into {}",
        evaluations_path.display()
    );

    Ok(())
}

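/// Aggregate tallies across all evaluated queries.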
#[derive(Default, Debug)]
struct Counts {
    covered_results: usize,
    overlapped_results: usize,
    covered_files: usize,
    total_results: usize,
}

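/// Runs the evaluation against every fetched repository (or only `only_repo` when given), printing
/// one JSON outcome per query to stdout and a progress summary to stderr.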
async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncApp,
) -> Result<()> {
    let mut http_client = None;
    cx.update(|cx| {
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        http_client = Some(cx.http_client());
        cx.update_flags(false, vec![]);
    })
    .unwrap();
    let http_client = http_client.unwrap();
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
    let fs = Arc::new(RealFs::new(None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    http_client.clone(),
                    "https://zed.dev",
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)).unwrap();
    let node_runtime = NodeRuntime::unavailable();

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

    let mut counts = Counts::default();
    eprint!("Running evals.");

    let mut failures = Vec::new();

    for evaluation_project in evaluations {
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            counts.covered_results,
            counts.total_results,
            counts.overlapped_results,
            counts.total_results,
            counts.covered_files,
            counts.total_results,
            evaluation_project.repo
        );

        let repo_dir = repos_dir.join(&evaluation_project.repo);
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!(
                "Skipping {}: repository not fetched or marked to skip",
                evaluation_project.repo
            );
            continue;
        }

        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));

        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let repo = evaluation_project.repo.clone();
        if let Err(err) = run_eval_project(
            evaluation_project,
            &user_store,
            repo_db_path,
            &repo_dir,
            &mut counts,
            project,
            embedding_provider.clone(),
            fs.clone(),
            cx,
        )
        .await
        {
            eprintln!("{repo} eval failed with error: {:?}", err);

            failures.push((repo, err));
        }
    }

    eprintln!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
        counts.covered_results,
        counts.total_results,
        counts.overlapped_results,
        counts.total_results,
        counts.covered_files,
        counts.total_results,
        failures.len(),
    );

    if failures.is_empty() {
        Ok(())
    } else {
        eprintln!("Failures:\n");

        for (index, (repo, failure)) in failures.iter().enumerate() {
            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
        }

        Err(anyhow::anyhow!("Some evals failed."))
    }
}

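/// Indexes a single repository with the semantic index, runs every query against it, and records
/// coverage statistics for each expected result.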
#[allow(clippy::too_many_arguments)]
async fn run_eval_project(
    evaluation_project: EvaluationProject,
    user_store: &Entity<UserStore>,
    repo_db_path: PathBuf,
    repo_dir: &Path,
    counts: &mut Counts,
    project: Entity<Project>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    fs: Arc<dyn Fs>,
    cx: &mut AsyncApp,
) -> Result<(), anyhow::Error> {
    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;

    let (worktree, _) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(repo_dir, true, cx)
        })?
        .await?;

    worktree
        .update(cx, |worktree, _| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;

    for query in evaluation_project.queries {
        let results = {
            // Retry search up to 3 times in case of timeout, network failure, etc.
            let mut retries_remaining = 3;
            let mut result;

            loop {
                match cx.update(|cx| {
                    let project_index = project_index.read(cx);
                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
                }) {
                    Ok(task) => match task.await {
                        Ok(answer) => {
                            result = Ok(answer);
                            break;
                        }
                        Err(err) => {
                            result = Err(err);
                        }
                    },
                    Err(err) => {
                        result = Err(err);
                    }
                }

                if retries_remaining > 0 {
                    eprintln!(
                        "Retrying search after it failed on query {:?} with {:?}",
                        query, result
                    );
                    retries_remaining -= 1;
                } else {
                    eprintln!(
                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
                        query, result
                    );
                    break;
                }
            }

            SemanticDb::load_results(result?, &fs.clone(), &cx).await?
        };

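        // Compare the returned excerpts against the annotated ranges: a result is covered when an
        // excerpt contains both endpoints of the expected range, overlapped when it contains at
        // least one, and the file is counted whenever any excerpt comes from the expected path.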
        let mut project_covered_result_count = 0;
        let mut project_overlapped_result_count = 0;
        let mut project_covered_file_count = 0;
        let mut covered_result_indices = Vec::new();
        for expected_result in &query.expected_results {
            let mut file_matched = false;
            let mut range_overlapped = false;
            let mut range_covered = false;

            for (ix, result) in results.iter().enumerate() {
                if result.path.as_ref() == Path::new(&expected_result.file) {
                    file_matched = true;
                    let start_matched = result.row_range.contains(&expected_result.lines.start());
                    let end_matched = result.row_range.contains(&expected_result.lines.end());

                    if start_matched || end_matched {
                        range_overlapped = true;
                    }

                    if start_matched && end_matched {
                        range_covered = true;
                        covered_result_indices.push(ix);
                        break;
                    }
                }
            }

            if range_covered {
                project_covered_result_count += 1
            };
            if range_overlapped {
                project_overlapped_result_count += 1
            };
            if file_matched {
                project_covered_file_count += 1
            };
        }
        let outcome_repo = evaluation_project.repo.clone();

        let query_results = EvaluationQueryOutcome {
            repo: outcome_repo,
            query: query.query,
            total_result_count: query.expected_results.len(),
            covered_result_count: project_covered_result_count,
            overlapped_result_count: project_overlapped_result_count,
            covered_file_count: project_covered_file_count,
            expected_results: query.expected_results,
            actual_results: results
                .iter()
                .map(|result| EvaluationSearchResult {
                    file: result.path.to_string_lossy().to_string(),
                    lines: result.row_range.clone(),
                })
                .collect(),
            covered_result_indices,
        };

        counts.overlapped_results += query_results.overlapped_result_count;
        counts.covered_results += query_results.covered_result_count;
        counts.covered_files += query_results.covered_file_count;
        counts.total_results += query_results.total_result_count;

        println!("{}", serde_json::to_string(&query_results)?);
    }

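    // Explicitly drop the per-repo state inside an entity update; the update's `Result` also
    // serves as this function's return value.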
    user_store.update(cx, |_, _| {
        drop(semantic_index);
        drop(project);
        drop(worktree);
        drop(project_index);
    })
}

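/// Waits for the project index to report `Status::Idle`, optionally giving up after `timeout`.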
async fn wait_for_indexing_complete(
    project_index: &Entity<ProjectIndex>,
    cx: &mut AsyncApp,
    timeout: Option<Duration>,
) {
    let (tx, rx) = bounded(1);
    let subscription = cx.update(|cx| {
        cx.subscribe(project_index, move |_, event, _| {
            if let Status::Idle = event {
                let _ = tx.try_send(*event);
            }
        })
    });

    let result = match timeout {
        Some(timeout_duration) => {
            smol::future::or(
                async {
                    rx.recv().await.map_err(|_| ())?;
                    Ok(())
                },
                async {
                    Timer::after(timeout_duration).await;
                    Err(())
                },
            )
            .await
        }
        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
    };

    match result {
        Ok(_) => (),
        Err(_) => {
            if let Some(timeout) = timeout {
                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
            }
        }
    }

    drop(subscription);
}

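/// Clones every repository referenced by `evaluations.json`, splitting the work across roughly
/// eight parallel tasks on the background executor.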
async fn fetch_eval_repos(
    executor: &BackgroundExecutor,
    http_client: &dyn HttpClient,
) -> Result<()> {
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    eprintln!("Fetching evaluation repositories...");

    executor
        .scoped(move |scope| {
            let done_count = Arc::new(AtomicUsize::new(0));
            let len = evaluations.len();
            // Guard against a zero chunk size (which would panic) when there are fewer than 8 repos.
            for chunk in evaluations.chunks((evaluations.len() / 8).max(1)) {
                let chunk = chunk.to_vec();
                let done_count = done_count.clone();
                scope.spawn(async move {
                    for EvaluationProject { repo, sha, .. } in chunk {
                        eprint!(
                            "\rFetching evaluation repositories ({}/{})...",
                            done_count.load(SeqCst),
                            len,
                        );

                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
                        done_count.fetch_add(1, SeqCst);
                    }
                });
            }
        })
        .await;

    Ok(())
}

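/// Shallow-clones a single repository at the given SHA into `repos_dir`, marking it with a
/// `.skip_eval` file if it is no longer publicly accessible.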
async fn fetch_eval_repo(
    repo: String,
    sha: String,
    repos_dir: &Path,
    http_client: &dyn HttpClient,
) {
    let Some((owner, repo_name)) = repo.split_once('/') else {
        return;
    };
    let repo_dir = repos_dir.join(owner).join(repo_name);
    fs::create_dir_all(&repo_dir).unwrap();
    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
    if skip_eval_path.exists() {
        return;
    }
    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
        if head_content.trim() == sha {
            return;
        }
    }
    let repo_response = http_client
        .send(
            http_client::Request::builder()
                .method(Method::HEAD)
                .uri(format!("https://github.com/{}", repo))
                .body(Default::default())
                .expect("failed to build HEAD request"),
        )
        .await
        .expect("failed to check github repo");
    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
        fs::write(&skip_eval_path, "").unwrap();
        eprintln!(
            "Repo {repo} is no longer public ({:?}). Skipping",
            repo_response.status()
        );
        return;
    }
    if !repo_dir.join(".git").exists() {
        let init_output = util::command::new_std_command("git")
            .current_dir(&repo_dir)
            .args(&["init"])
            .output()
            .unwrap();
        if !init_output.status.success() {
            eprintln!(
                "Failed to initialize git repository for {}: {}",
                repo,
                String::from_utf8_lossy(&init_output.stderr)
            );
            return;
        }
    }
    let url = format!("https://github.com/{}.git", repo);
    util::command::new_std_command("git")
        .current_dir(&repo_dir)
        .args(&["remote", "add", "-f", "origin", &url])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    let fetch_output = util::command::new_std_command("git")
        .current_dir(&repo_dir)
        .args(&["fetch", "--depth", "1", "origin", &sha])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    if !fetch_output.status.success() {
        eprintln!(
            "Failed to fetch {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&fetch_output.stderr)
        );
        return;
    }
    let checkout_output = util::command::new_std_command("git")
        .current_dir(&repo_dir)
        .args(&["checkout", &sha])
        .output()
        .unwrap();

    if !checkout_output.status.success() {
        eprintln!(
            "Failed to checkout {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&checkout_output.stderr)
        );
    }
}