use ::fs::{Fs, RealFs};
use anyhow::{Context as _, Result};
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::FakeNodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use semantic_index::{OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Command, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};
32
/// Directory where the CodeSearchNet annotations and `evaluations.json` are cached.
const CODESEARCH_NET_DIR: &str = "target/datasets/code-search-net";
/// Directory into which the evaluated repositories are checked out.
const EVAL_REPOS_DIR: &str = "target/datasets/eval-repos";
/// Root directory for the per-repository semantic index databases.
const EVAL_DB_PATH: &str = "target/eval_db";
/// Maximum number of search results requested per evaluation query.
const SEARCH_RESULT_LIMIT: usize = 8;
/// Marker file written into a repo dir to exclude it from future runs.
const SKIP_EVAL_PATH: &str = ".skip_eval";
38
// Command-line interface of the evaluation binary.
// NOTE: plain `//` comments are used here on purpose — `///` doc comments on
// clap-derived items become part of the generated `--help` output.
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    // Which subcommand to execute (fetch datasets, or run the evaluation).
    #[command(subcommand)]
    command: Commands,
}
45
// Subcommands of the evaluation binary.
// `Fetch` downloads the CodeSearchNet annotations and clones the referenced
// repos; `Run` executes the evaluation over the fetched data.
// (`//` comments, not `///`, to avoid changing clap's generated help text.)
#[derive(clap::Subcommand)]
enum Commands {
    Fetch {},
    Run {
        // Restrict the run to a single repository (`owner/name`); when absent,
        // every project in evaluations.json is evaluated.
        #[arg(long)]
        repo: Option<String>,
    },
}
54
/// A repository pinned to a commit, together with the annotated search
/// queries that CodeSearchNet associates with it.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    /// Repository in `owner/name` form (parsed from GitHub blob URLs).
    repo: String,
    /// Commit SHA the annotations refer to.
    sha: String,
    /// Annotated queries for this repo at this commit.
    queries: Vec<EvaluationQuery>,
}
61
/// A natural-language search query and the file ranges that human annotators
/// marked as relevant answers for it.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    /// The search query text.
    query: String,
    /// Ranges a correct search should surface (deduplicated at parse time).
    expected_results: Vec<EvaluationSearchResult>,
}
67
/// A single expected or actual search hit: a file and an inclusive line range
/// within it (parsed from the `#L<start>-L<end>` fragment of GitHub URLs).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    /// File path relative to the repository root.
    file: String,
    /// Inclusive range of relevant lines within `file`.
    lines: RangeInclusive<u32>,
}
73
/// Aggregated outcomes for one evaluated repository.
/// NOTE(review): not constructed anywhere in this file — per-query outcomes
/// are emitted individually as `EvaluationQueryOutcome`; presumably consumed
/// (or intended for use) by downstream tooling — verify before removing.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}
80
/// The measured outcome of running one query against the semantic index.
/// Serialized as one JSON line per query on stdout by `run_evaluation`.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    /// Repository (`owner/name`) the query was evaluated against.
    repo: String,
    /// The query text.
    query: String,
    /// Ranges the annotations say should be found.
    expected_results: Vec<EvaluationSearchResult>,
    /// Ranges the semantic search actually returned.
    actual_results: Vec<EvaluationSearchResult>,
    /// Count of expected results whose file appeared among the actual results.
    covered_file_count: usize,
    /// Count of expected results whose range overlapped some actual result.
    overlapped_result_count: usize,
    /// Count of expected results fully contained in some actual result.
    covered_result_count: usize,
    /// Total number of expected results for this query.
    total_result_count: usize,
    /// Indices into `actual_results` that fully covered an expected result.
    covered_result_indices: Vec<usize>,
}
93
94fn main() -> Result<()> {
95 let cli = Cli::parse();
96 env_logger::init();
97
98 gpui::App::headless().run(move |cx| {
99 let executor = cx.background_executor().clone();
100 let client = isahc_http_client::IsahcHttpClient::new(None, None);
101 cx.set_http_client(client.clone());
102 match cli.command {
103 Commands::Fetch {} => {
104 executor
105 .clone()
106 .spawn(async move {
107 if let Err(err) = fetch_evaluation_resources(client, &executor).await {
108 eprintln!("Error: {}", err);
109 exit(1);
110 }
111 exit(0);
112 })
113 .detach();
114 }
115 Commands::Run { repo } => {
116 cx.spawn(|mut cx| async move {
117 if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
118 eprintln!("Error: {}", err);
119 exit(1);
120 }
121 exit(0);
122 })
123 .detach();
124 }
125 }
126 });
127
128 Ok(())
129}
130
131async fn fetch_evaluation_resources(
132 http_client: Arc<dyn HttpClient>,
133 executor: &BackgroundExecutor,
134) -> Result<()> {
135 fetch_code_search_net_resources(&*http_client).await?;
136 fetch_eval_repos(executor, &*http_client).await?;
137 Ok(())
138}
139
140async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
141 eprintln!("Fetching CodeSearchNet evaluations...");
142
143 let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
144
145 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
146 fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
147
148 // Fetch the annotations CSV, which contains the human-annotated search relevances
149 let annotations_path = dataset_dir.join("annotations.csv");
150 let annotations_csv_content = if annotations_path.exists() {
151 fs::read_to_string(&annotations_path).expect("failed to read annotations")
152 } else {
153 let response = http_client
154 .get(annotations_url, Default::default(), true)
155 .await
156 .expect("failed to fetch annotations csv");
157 let mut body = String::new();
158 response
159 .into_body()
160 .read_to_string(&mut body)
161 .await
162 .expect("failed to read annotations.csv response");
163 fs::write(annotations_path, &body).expect("failed to write annotations.csv");
164 body
165 };
166
167 // Parse the annotations CSV. Skip over queries with zero relevance.
168 let rows = annotations_csv_content.lines().filter_map(|line| {
169 let mut values = line.split(',');
170 let _language = values.next()?;
171 let query = values.next()?;
172 let github_url = values.next()?;
173 let score = values.next()?;
174
175 if score == "0" {
176 return None;
177 }
178
179 let url_path = github_url.strip_prefix("https://github.com/")?;
180 let (url_path, hash) = url_path.split_once('#')?;
181 let (repo_name, url_path) = url_path.split_once("/blob/")?;
182 let (sha, file_path) = url_path.split_once('/')?;
183 let line_range = if let Some((start, end)) = hash.split_once('-') {
184 start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
185 } else {
186 let row = hash.strip_prefix("L")?.parse().ok()?;
187 row..=row
188 };
189 Some((repo_name, sha, query, file_path, line_range))
190 });
191
192 // Group the annotations by repo and sha.
193 let mut evaluations_by_repo = BTreeMap::new();
194 for (repo_name, sha, query, file_path, lines) in rows {
195 let evaluation_project = evaluations_by_repo
196 .entry((repo_name, sha))
197 .or_insert_with(|| EvaluationProject {
198 repo: repo_name.to_string(),
199 sha: sha.to_string(),
200 queries: Vec::new(),
201 });
202
203 let ix = evaluation_project
204 .queries
205 .iter()
206 .position(|entry| entry.query == query)
207 .unwrap_or_else(|| {
208 evaluation_project.queries.push(EvaluationQuery {
209 query: query.to_string(),
210 expected_results: Vec::new(),
211 });
212 evaluation_project.queries.len() - 1
213 });
214 let results = &mut evaluation_project.queries[ix].expected_results;
215 let result = EvaluationSearchResult {
216 file: file_path.to_string(),
217 lines,
218 };
219 if !results.contains(&result) {
220 results.push(result);
221 }
222 }
223
224 let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
225 let evaluations_path = dataset_dir.join("evaluations.json");
226 fs::write(
227 &evaluations_path,
228 serde_json::to_vec_pretty(&evaluations).unwrap(),
229 )
230 .unwrap();
231
232 eprintln!(
233 "Fetched CodeSearchNet evaluations into {}",
234 evaluations_path.display()
235 );
236
237 Ok(())
238}
239
/// Runs the evaluation over every project in `evaluations.json`, or over a
/// single repository when `only_repo` is given.
///
/// For each project this opens a per-repo semantic index database, indexes
/// the checked-out worktree, runs each annotated query, and scores the
/// returned results three ways against each expected result:
/// - file coverage: some actual result is in the expected file;
/// - overlap: some actual result's row range touches the expected range's
///   start or end;
/// - coverage: some actual result's row range contains both the expected
///   start and end.
///
/// Per-query [`EvaluationQueryOutcome`]s are printed to stdout as JSON lines;
/// progress and running totals go to stderr. Panics if `OPENAI_API_KEY` is
/// not set or if any of the setup steps fail.
async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncAppContext,
) -> Result<()> {
    // Initialize the global app state (settings, languages, project support)
    // that SemanticDb and Project require, and grab the app's HTTP client.
    let mut http_client = None;
    cx.update(|cx| {
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        http_client = Some(cx.http_client());
        cx.update_flags(false, vec![]);
    })
    .unwrap();
    let http_client = http_client.unwrap();
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    // Embeddings require a real OpenAI key; fail fast if it is missing.
    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    http_client.clone(),
                    "https://zed.dev",
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx
        .new_model(|cx| UserStore::new(client.clone(), cx))
        .unwrap();
    let node_runtime = Arc::new(FakeNodeRuntime {});

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    // Embeddings are produced with OpenAI's small text-embedding-3 model.
    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

    // Running totals across all projects/queries, reported on stderr.
    let mut covered_result_count = 0;
    let mut overlapped_result_count = 0;
    let mut covered_file_count = 0;
    let mut total_result_count = 0;
    eprint!("Running evals.");

    for evaluation_project in evaluations {
        // When `--repo` was passed, skip every other project.
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        // `\r` + "erase line" escape: rewrite the progress line in place.
        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            covered_result_count,
            total_result_count,
            overlapped_result_count,
            total_result_count,
            covered_file_count,
            total_result_count,
            evaluation_project.repo
        );

        // Each repo gets its own on-disk index database, keyed by its name.
        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
        let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider.clone(), cx)
            .await
            .unwrap();

        let repo_dir = repos_dir.join(&evaluation_project.repo);
        // NOTE(review): this message also fires when a `.skip_eval` marker
        // exists, not only when the directory is actually missing.
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
            continue;
        }

        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let (worktree, _) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(repo_dir, true, cx)
            })?
            .await?;

        // Wait for the worktree scan so all files are known before indexing.
        worktree
            .update(cx, |worktree, _| {
                worktree.as_local().unwrap().scan_complete()
            })
            .unwrap()
            .await;

        let project_index = cx
            .update(|cx| semantic_index.create_project_index(project.clone(), cx))
            .unwrap();
        wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;

        for query in evaluation_project.queries {
            let results = cx
                .update(|cx| {
                    let project_index = project_index.read(cx);
                    project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx)
                })
                .unwrap()
                .await
                .unwrap();

            let results = SemanticDb::load_results(results, &fs.clone(), &cx)
                .await
                .unwrap();

            // Score this query's actual results against each expected range.
            let mut project_covered_result_count = 0;
            let mut project_overlapped_result_count = 0;
            let mut project_covered_file_count = 0;
            let mut covered_result_indices = Vec::new();
            for expected_result in &query.expected_results {
                let mut file_matched = false;
                let mut range_overlapped = false;
                let mut range_covered = false;

                for (ix, result) in results.iter().enumerate() {
                    if result.path.as_ref() == Path::new(&expected_result.file) {
                        file_matched = true;
                        let start_matched =
                            result.row_range.contains(&expected_result.lines.start());
                        let end_matched = result.row_range.contains(&expected_result.lines.end());

                        if start_matched || end_matched {
                            range_overlapped = true;
                        }

                        // Fully covered: record the index and stop scanning.
                        if start_matched && end_matched {
                            range_covered = true;
                            covered_result_indices.push(ix);
                            break;
                        }
                    }
                }

                if range_covered {
                    project_covered_result_count += 1
                };
                if range_overlapped {
                    project_overlapped_result_count += 1
                };
                if file_matched {
                    project_covered_file_count += 1
                };
            }
            let outcome_repo = evaluation_project.repo.clone();

            let query_results = EvaluationQueryOutcome {
                repo: outcome_repo,
                query: query.query,
                total_result_count: query.expected_results.len(),
                covered_result_count: project_covered_result_count,
                overlapped_result_count: project_overlapped_result_count,
                covered_file_count: project_covered_file_count,
                expected_results: query.expected_results,
                actual_results: results
                    .iter()
                    .map(|result| EvaluationSearchResult {
                        file: result.path.to_string_lossy().to_string(),
                        lines: result.row_range.clone(),
                    })
                    .collect(),
                covered_result_indices,
            };

            overlapped_result_count += query_results.overlapped_result_count;
            covered_result_count += query_results.covered_result_count;
            covered_file_count += query_results.covered_file_count;
            total_result_count += query_results.total_result_count;

            // One JSON line per query on stdout for downstream analysis.
            println!("{}", serde_json::to_string(&query_results).unwrap());
        }

        // Tear down per-project state inside an update so the drops happen
        // on the app thread before moving to the next project.
        user_store
            .update(cx, |_, _| {
                drop(semantic_index);
                drop(project);
                drop(worktree);
                drop(project_index);
            })
            .unwrap();
    }

    // Final tally (no trailing newline, matching the progress-line style).
    eprint!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured.",
        covered_result_count,
        total_result_count,
        overlapped_result_count,
        total_result_count,
        covered_file_count,
        total_result_count,
    );

    Ok(())
}
472
473async fn wait_for_indexing_complete(
474 project_index: &Model<ProjectIndex>,
475 cx: &mut AsyncAppContext,
476 timeout: Option<Duration>,
477) {
478 let (tx, rx) = bounded(1);
479 let subscription = cx.update(|cx| {
480 cx.subscribe(project_index, move |_, event, _| {
481 if let Status::Idle = event {
482 let _ = tx.try_send(*event);
483 }
484 })
485 });
486
487 let result = match timeout {
488 Some(timeout_duration) => {
489 smol::future::or(
490 async {
491 rx.recv().await.map_err(|_| ())?;
492 Ok(())
493 },
494 async {
495 Timer::after(timeout_duration).await;
496 Err(())
497 },
498 )
499 .await
500 }
501 None => rx.recv().await.map(|_| ()).map_err(|_| ()),
502 };
503
504 match result {
505 Ok(_) => (),
506 Err(_) => {
507 if let Some(timeout) = timeout {
508 eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
509 }
510 }
511 }
512
513 drop(subscription);
514}
515
516async fn fetch_eval_repos(
517 executor: &BackgroundExecutor,
518 http_client: &dyn HttpClient,
519) -> Result<()> {
520 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
521 let evaluations_path = dataset_dir.join("evaluations.json");
522 let repos_dir = Path::new(EVAL_REPOS_DIR);
523
524 let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
525 let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
526
527 eprint!("Fetching evaluation repositories...");
528
529 executor
530 .scoped(move |scope| {
531 let done_count = Arc::new(AtomicUsize::new(0));
532 let len = evaluations.len();
533 for chunk in evaluations.chunks(evaluations.len() / 8) {
534 let chunk = chunk.to_vec();
535 let done_count = done_count.clone();
536 scope.spawn(async move {
537 for EvaluationProject { repo, sha, .. } in chunk {
538 eprint!(
539 "\rFetching evaluation repositories ({}/{})...",
540 done_count.load(SeqCst),
541 len,
542 );
543
544 fetch_eval_repo(repo, sha, repos_dir, http_client).await;
545 done_count.fetch_add(1, SeqCst);
546 }
547 });
548 }
549 })
550 .await;
551
552 Ok(())
553}
554
555async fn fetch_eval_repo(
556 repo: String,
557 sha: String,
558 repos_dir: &Path,
559 http_client: &dyn HttpClient,
560) {
561 let Some((owner, repo_name)) = repo.split_once('/') else {
562 return;
563 };
564 let repo_dir = repos_dir.join(owner).join(repo_name);
565 fs::create_dir_all(&repo_dir).unwrap();
566 let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
567 if skip_eval_path.exists() {
568 return;
569 }
570 if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
571 if head_content.trim() == sha {
572 return;
573 }
574 }
575 let repo_response = http_client
576 .send(
577 http_client::Request::builder()
578 .method(Method::HEAD)
579 .uri(format!("https://github.com/{}", repo))
580 .body(Default::default())
581 .expect(""),
582 )
583 .await
584 .expect("failed to check github repo");
585 if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
586 fs::write(&skip_eval_path, "").unwrap();
587 eprintln!(
588 "Repo {repo} is no longer public ({:?}). Skipping",
589 repo_response.status()
590 );
591 return;
592 }
593 if !repo_dir.join(".git").exists() {
594 let init_output = Command::new("git")
595 .current_dir(&repo_dir)
596 .args(&["init"])
597 .output()
598 .unwrap();
599 if !init_output.status.success() {
600 eprintln!(
601 "Failed to initialize git repository for {}: {}",
602 repo,
603 String::from_utf8_lossy(&init_output.stderr)
604 );
605 return;
606 }
607 }
608 let url = format!("https://github.com/{}.git", repo);
609 Command::new("git")
610 .current_dir(&repo_dir)
611 .args(&["remote", "add", "-f", "origin", &url])
612 .stdin(Stdio::null())
613 .output()
614 .unwrap();
615 let fetch_output = Command::new("git")
616 .current_dir(&repo_dir)
617 .args(&["fetch", "--depth", "1", "origin", &sha])
618 .stdin(Stdio::null())
619 .output()
620 .unwrap();
621 if !fetch_output.status.success() {
622 eprintln!(
623 "Failed to fetch {} for {}: {}",
624 sha,
625 repo,
626 String::from_utf8_lossy(&fetch_output.stderr)
627 );
628 return;
629 }
630 let checkout_output = Command::new("git")
631 .current_dir(&repo_dir)
632 .args(&["checkout", &sha])
633 .output()
634 .unwrap();
635
636 if !checkout_output.status.success() {
637 eprintln!(
638 "Failed to checkout {} for {}: {}",
639 sha,
640 repo,
641 String::from_utf8_lossy(&checkout_output.stderr)
642 );
643 }
644}