use ::fs::{Fs, RealFs};
use anyhow::{Context as _, Result};
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::FakeNodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use semantic_index::{OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Command, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};
32
/// Directory where `fetch` stores the CodeSearchNet annotations CSV and the
/// derived `evaluations.json`.
const CODESEARCH_NET_DIR: &str = "target/datasets/code-search-net";
/// Directory under which each evaluation repository is checked out as
/// `{owner}/{name}`.
const EVAL_REPOS_DIR: &str = "target/datasets/eval-repos";
/// Directory holding one semantic-index database (`{owner}_{name}.db`) per
/// evaluated repository.
const EVAL_DB_PATH: &str = "target/eval_db";
/// Result limit passed to the project index search for every query.
const SEARCH_RESULT_LIMIT: usize = 8;
/// Marker file placed in a repository directory to exclude it from fetching
/// and evaluation.
const SKIP_EVAL_PATH: &str = ".skip_eval";
38
// Top-level command-line arguments; `main` parses these and dispatches on
// `command`.
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    // The subcommand to run.
    #[command(subcommand)]
    command: Commands,
}
45
// Subcommands accepted by the CLI.
#[derive(clap::Subcommand)]
enum Commands {
    // Download the CodeSearchNet annotations and check out the referenced
    // repositories (handled by `fetch_evaluation_resources`).
    Fetch {},
    // Index the checked-out repositories and score search results against the
    // annotations (handled by `run_evaluation`).
    Run {
        // When given, only the project whose `repo` equals this value is
        // evaluated; all others are skipped.
        #[arg(long)]
        repo: Option<String>,
    },
}
54
/// One repository at a pinned commit together with the queries evaluated
/// against it. `evaluations.json` is a JSON array of these.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    /// Repository path taken from the GitHub URL, in `owner/name` form.
    repo: String,
    /// Commit the annotated file paths and line ranges refer to.
    sha: String,
    /// Queries that have at least one annotation with a non-zero relevance
    /// score in this repository.
    queries: Vec<EvaluationQuery>,
}
61
/// A natural-language search query and the code locations annotated as
/// relevant to it.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    /// Query text passed verbatim to the project index search.
    query: String,
    /// Relevant locations; duplicates are removed when the dataset is built.
    expected_results: Vec<EvaluationSearchResult>,
}
67
/// A file plus an inclusive line range. Used both for expected (annotated)
/// results and for the actual results returned by the search.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    /// File path relative to the repository root, as it appears in the GitHub
    /// URL (expected results) or in the search result (actual results).
    file: String,
    /// Inclusive line range. Expected ranges come from GitHub `#L<n>[-L<m>]`
    /// anchors; actual ranges come from the search result's `row_range`.
    // NOTE(review): confirm both sources use the same (0- vs 1-based) line
    // numbering before relying on exact coverage numbers.
    lines: RangeInclusive<u32>,
}
73
/// Per-project collection of query outcomes, mirroring [`EvaluationProject`].
// NOTE(review): not constructed anywhere in this file; possibly dead code —
// confirm before removing.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    /// Repository in `owner/name` form.
    repo: String,
    /// Commit that was evaluated.
    sha: String,
    /// One outcome per evaluated query.
    queries: Vec<EvaluationQueryOutcome>,
}
80
/// Scoring outcome for a single query. `run_evaluation` prints one of these
/// per query to stdout as a line of JSON.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    /// Repository the query was evaluated against.
    repo: String,
    /// The query text.
    query: String,
    /// Annotated relevant locations.
    expected_results: Vec<EvaluationSearchResult>,
    /// Locations returned by the search, in the order the search returned them.
    actual_results: Vec<EvaluationSearchResult>,
    /// Number of expected results whose file appears among the actual results.
    covered_file_count: usize,
    /// Number of expected results that overlap an actual result in the same file.
    overlapped_result_count: usize,
    /// Number of expected results whose whole range lies inside a single
    /// actual result's range.
    covered_result_count: usize,
    /// Number of expected results; the denominator for the counts above.
    total_result_count: usize,
    /// Indices into `actual_results` of the results that fully covered an
    /// expected result.
    covered_result_indices: Vec<usize>,
}
93
94fn main() -> Result<()> {
95 let cli = Cli::parse();
96 env_logger::init();
97
98 gpui::App::headless().run(move |cx| {
99 let executor = cx.background_executor().clone();
100 let client = isahc_http_client::IsahcHttpClient::new(None, None);
101 cx.set_http_client(client.clone());
102 match cli.command {
103 Commands::Fetch {} => {
104 executor
105 .clone()
106 .spawn(async move {
107 if let Err(err) = fetch_evaluation_resources(client, &executor).await {
108 eprintln!("Error: {}", err);
109 exit(1);
110 }
111 exit(0);
112 })
113 .detach();
114 }
115 Commands::Run { repo } => {
116 cx.spawn(|mut cx| async move {
117 if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
118 eprintln!("Error: {}", err);
119 exit(1);
120 }
121 exit(0);
122 })
123 .detach();
124 }
125 }
126 });
127
128 Ok(())
129}
130
131async fn fetch_evaluation_resources(
132 http_client: Arc<dyn HttpClient>,
133 executor: &BackgroundExecutor,
134) -> Result<()> {
135 fetch_code_search_net_resources(&*http_client).await?;
136 fetch_eval_repos(executor, &*http_client).await?;
137 Ok(())
138}
139
140async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
141 eprintln!("Fetching CodeSearchNet evaluations...");
142
143 let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
144
145 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
146 fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
147
148 // Fetch the annotations CSV, which contains the human-annotated search relevances
149 let annotations_path = dataset_dir.join("annotations.csv");
150 let annotations_csv_content = if annotations_path.exists() {
151 fs::read_to_string(&annotations_path).expect("failed to read annotations")
152 } else {
153 let response = http_client
154 .get(annotations_url, Default::default(), true)
155 .await
156 .expect("failed to fetch annotations csv");
157 let mut body = String::new();
158 response
159 .into_body()
160 .read_to_string(&mut body)
161 .await
162 .expect("failed to read annotations.csv response");
163 fs::write(annotations_path, &body).expect("failed to write annotations.csv");
164 body
165 };
166
167 // Parse the annotations CSV. Skip over queries with zero relevance.
168 let rows = annotations_csv_content.lines().filter_map(|line| {
169 let mut values = line.split(',');
170 let _language = values.next()?;
171 let query = values.next()?;
172 let github_url = values.next()?;
173 let score = values.next()?;
174
175 if score == "0" {
176 return None;
177 }
178
179 let url_path = github_url.strip_prefix("https://github.com/")?;
180 let (url_path, hash) = url_path.split_once('#')?;
181 let (repo_name, url_path) = url_path.split_once("/blob/")?;
182 let (sha, file_path) = url_path.split_once('/')?;
183 let line_range = if let Some((start, end)) = hash.split_once('-') {
184 start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
185 } else {
186 let row = hash.strip_prefix("L")?.parse().ok()?;
187 row..=row
188 };
189 Some((repo_name, sha, query, file_path, line_range))
190 });
191
192 // Group the annotations by repo and sha.
193 let mut evaluations_by_repo = BTreeMap::new();
194 for (repo_name, sha, query, file_path, lines) in rows {
195 let evaluation_project = evaluations_by_repo
196 .entry((repo_name, sha))
197 .or_insert_with(|| EvaluationProject {
198 repo: repo_name.to_string(),
199 sha: sha.to_string(),
200 queries: Vec::new(),
201 });
202
203 let ix = evaluation_project
204 .queries
205 .iter()
206 .position(|entry| entry.query == query)
207 .unwrap_or_else(|| {
208 evaluation_project.queries.push(EvaluationQuery {
209 query: query.to_string(),
210 expected_results: Vec::new(),
211 });
212 evaluation_project.queries.len() - 1
213 });
214 let results = &mut evaluation_project.queries[ix].expected_results;
215 let result = EvaluationSearchResult {
216 file: file_path.to_string(),
217 lines,
218 };
219 if !results.contains(&result) {
220 results.push(result);
221 }
222 }
223
224 let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
225 let evaluations_path = dataset_dir.join("evaluations.json");
226 fs::write(
227 &evaluations_path,
228 serde_json::to_vec_pretty(&evaluations).unwrap(),
229 )
230 .unwrap();
231
232 eprintln!(
233 "Fetched CodeSearchNet evaluations into {}",
234 evaluations_path.display()
235 );
236
237 Ok(())
238}
239
240async fn run_evaluation(
241 only_repo: Option<String>,
242 executor: &BackgroundExecutor,
243 cx: &mut AsyncAppContext,
244) -> Result<()> {
245 let mut http_client = None;
246 cx.update(|cx| {
247 let mut store = SettingsStore::new(cx);
248 store
249 .set_default_settings(settings::default_settings().as_ref(), cx)
250 .unwrap();
251 cx.set_global(store);
252 client::init_settings(cx);
253 language::init(cx);
254 Project::init_settings(cx);
255 http_client = Some(cx.http_client());
256 cx.update_flags(false, vec![]);
257 })
258 .unwrap();
259 let http_client = http_client.unwrap();
260 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
261 let evaluations_path = dataset_dir.join("evaluations.json");
262 let repos_dir = Path::new(EVAL_REPOS_DIR);
263 let db_path = Path::new(EVAL_DB_PATH);
264 let api_key = std::env::var("OPENAI_API_KEY").unwrap();
265 let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
266 let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
267 let clock = Arc::new(RealSystemClock);
268 let client = cx
269 .update(|cx| {
270 Client::new(
271 clock,
272 Arc::new(http_client::HttpClientWithUrl::new(
273 http_client.clone(),
274 "https://zed.dev",
275 None,
276 )),
277 cx,
278 )
279 })
280 .unwrap();
281 let user_store = cx
282 .new_model(|cx| UserStore::new(client.clone(), cx))
283 .unwrap();
284 let node_runtime = Arc::new(FakeNodeRuntime {});
285
286 let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
287 let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
288
289 let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
290 http_client.clone(),
291 OpenAiEmbeddingModel::TextEmbedding3Small,
292 open_ai::OPEN_AI_API_URL.to_string(),
293 api_key,
294 ));
295
296 let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
297 cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
298 .unwrap();
299
300 let mut covered_result_count = 0;
301 let mut overlapped_result_count = 0;
302 let mut covered_file_count = 0;
303 let mut total_result_count = 0;
304 eprint!("Running evals.");
305
306 for evaluation_project in evaluations {
307 if only_repo
308 .as_ref()
309 .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
310 {
311 continue;
312 }
313
314 eprint!("\r\x1B[2K");
315 eprint!(
316 "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
317 covered_result_count,
318 total_result_count,
319 overlapped_result_count,
320 total_result_count,
321 covered_file_count,
322 total_result_count,
323 evaluation_project.repo
324 );
325
326 let repo_db_path =
327 db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
328 let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider.clone(), cx)
329 .await
330 .unwrap();
331
332 let repo_dir = repos_dir.join(&evaluation_project.repo);
333 if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
334 eprintln!("Skipping {}: directory not found", evaluation_project.repo);
335 continue;
336 }
337
338 let project = cx
339 .update(|cx| {
340 Project::local(
341 client.clone(),
342 node_runtime.clone(),
343 user_store.clone(),
344 language_registry.clone(),
345 fs.clone(),
346 None,
347 cx,
348 )
349 })
350 .unwrap();
351
352 let (worktree, _) = project
353 .update(cx, |project, cx| {
354 project.find_or_create_worktree(repo_dir, true, cx)
355 })?
356 .await?;
357
358 worktree
359 .update(cx, |worktree, _| {
360 worktree.as_local().unwrap().scan_complete()
361 })
362 .unwrap()
363 .await;
364
365 let project_index = cx
366 .update(|cx| semantic_index.create_project_index(project.clone(), cx))
367 .unwrap();
368 wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
369
370 for query in evaluation_project.queries {
371 let results = cx
372 .update(|cx| {
373 let project_index = project_index.read(cx);
374 project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx)
375 })
376 .unwrap()
377 .await
378 .unwrap();
379
380 let results = SemanticDb::load_results(results, &fs.clone(), &cx)
381 .await
382 .unwrap();
383
384 let mut project_covered_result_count = 0;
385 let mut project_overlapped_result_count = 0;
386 let mut project_covered_file_count = 0;
387 let mut covered_result_indices = Vec::new();
388 for expected_result in &query.expected_results {
389 let mut file_matched = false;
390 let mut range_overlapped = false;
391 let mut range_covered = false;
392
393 for (ix, result) in results.iter().enumerate() {
394 if result.path.as_ref() == Path::new(&expected_result.file) {
395 file_matched = true;
396 let start_matched =
397 result.row_range.contains(&expected_result.lines.start());
398 let end_matched = result.row_range.contains(&expected_result.lines.end());
399
400 if start_matched || end_matched {
401 range_overlapped = true;
402 }
403
404 if start_matched && end_matched {
405 range_covered = true;
406 covered_result_indices.push(ix);
407 break;
408 }
409 }
410 }
411
412 if range_covered {
413 project_covered_result_count += 1
414 };
415 if range_overlapped {
416 project_overlapped_result_count += 1
417 };
418 if file_matched {
419 project_covered_file_count += 1
420 };
421 }
422 let outcome_repo = evaluation_project.repo.clone();
423
424 let query_results = EvaluationQueryOutcome {
425 repo: outcome_repo,
426 query: query.query,
427 total_result_count: query.expected_results.len(),
428 covered_result_count: project_covered_result_count,
429 overlapped_result_count: project_overlapped_result_count,
430 covered_file_count: project_covered_file_count,
431 expected_results: query.expected_results,
432 actual_results: results
433 .iter()
434 .map(|result| EvaluationSearchResult {
435 file: result.path.to_string_lossy().to_string(),
436 lines: result.row_range.clone(),
437 })
438 .collect(),
439 covered_result_indices,
440 };
441
442 overlapped_result_count += query_results.overlapped_result_count;
443 covered_result_count += query_results.covered_result_count;
444 covered_file_count += query_results.covered_file_count;
445 total_result_count += query_results.total_result_count;
446
447 println!("{}", serde_json::to_string(&query_results).unwrap());
448 }
449 }
450
451 eprint!(
452 "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured.",
453 covered_result_count,
454 total_result_count,
455 overlapped_result_count,
456 total_result_count,
457 covered_file_count,
458 total_result_count,
459 );
460
461 Ok(())
462}
463
464async fn wait_for_indexing_complete(
465 project_index: &Model<ProjectIndex>,
466 cx: &mut AsyncAppContext,
467 timeout: Option<Duration>,
468) {
469 let (tx, rx) = bounded(1);
470 let subscription = cx.update(|cx| {
471 cx.subscribe(project_index, move |_, event, _| {
472 if let Status::Idle = event {
473 let _ = tx.try_send(*event);
474 }
475 })
476 });
477
478 let result = match timeout {
479 Some(timeout_duration) => {
480 smol::future::or(
481 async {
482 rx.recv().await.map_err(|_| ())?;
483 Ok(())
484 },
485 async {
486 Timer::after(timeout_duration).await;
487 Err(())
488 },
489 )
490 .await
491 }
492 None => rx.recv().await.map(|_| ()).map_err(|_| ()),
493 };
494
495 match result {
496 Ok(_) => (),
497 Err(_) => {
498 if let Some(timeout) = timeout {
499 eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
500 }
501 }
502 }
503
504 drop(subscription);
505}
506
507async fn fetch_eval_repos(
508 executor: &BackgroundExecutor,
509 http_client: &dyn HttpClient,
510) -> Result<()> {
511 let dataset_dir = Path::new(CODESEARCH_NET_DIR);
512 let evaluations_path = dataset_dir.join("evaluations.json");
513 let repos_dir = Path::new(EVAL_REPOS_DIR);
514
515 let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
516 let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
517
518 eprint!("Fetching evaluation repositories...");
519
520 executor
521 .scoped(move |scope| {
522 let done_count = Arc::new(AtomicUsize::new(0));
523 let len = evaluations.len();
524 for chunk in evaluations.chunks(evaluations.len() / 8) {
525 let chunk = chunk.to_vec();
526 let done_count = done_count.clone();
527 scope.spawn(async move {
528 for EvaluationProject { repo, sha, .. } in chunk {
529 eprint!(
530 "\rFetching evaluation repositories ({}/{})...",
531 done_count.load(SeqCst),
532 len,
533 );
534
535 fetch_eval_repo(repo, sha, repos_dir, http_client).await;
536 done_count.fetch_add(1, SeqCst);
537 }
538 });
539 }
540 })
541 .await;
542
543 Ok(())
544}
545
546async fn fetch_eval_repo(
547 repo: String,
548 sha: String,
549 repos_dir: &Path,
550 http_client: &dyn HttpClient,
551) {
552 let Some((owner, repo_name)) = repo.split_once('/') else {
553 return;
554 };
555 let repo_dir = repos_dir.join(owner).join(repo_name);
556 fs::create_dir_all(&repo_dir).unwrap();
557 let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
558 if skip_eval_path.exists() {
559 return;
560 }
561 if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
562 if head_content.trim() == sha {
563 return;
564 }
565 }
566 let repo_response = http_client
567 .send(
568 http_client::Request::builder()
569 .method(Method::HEAD)
570 .uri(format!("https://github.com/{}", repo))
571 .body(Default::default())
572 .expect(""),
573 )
574 .await
575 .expect("failed to check github repo");
576 if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
577 fs::write(&skip_eval_path, "").unwrap();
578 eprintln!(
579 "Repo {repo} is no longer public ({:?}). Skipping",
580 repo_response.status()
581 );
582 return;
583 }
584 if !repo_dir.join(".git").exists() {
585 let init_output = Command::new("git")
586 .current_dir(&repo_dir)
587 .args(&["init"])
588 .output()
589 .unwrap();
590 if !init_output.status.success() {
591 eprintln!(
592 "Failed to initialize git repository for {}: {}",
593 repo,
594 String::from_utf8_lossy(&init_output.stderr)
595 );
596 return;
597 }
598 }
599 let url = format!("https://github.com/{}.git", repo);
600 Command::new("git")
601 .current_dir(&repo_dir)
602 .args(&["remote", "add", "-f", "origin", &url])
603 .stdin(Stdio::null())
604 .output()
605 .unwrap();
606 let fetch_output = Command::new("git")
607 .current_dir(&repo_dir)
608 .args(&["fetch", "--depth", "1", "origin", &sha])
609 .stdin(Stdio::null())
610 .output()
611 .unwrap();
612 if !fetch_output.status.success() {
613 eprintln!(
614 "Failed to fetch {} for {}: {}",
615 sha,
616 repo,
617 String::from_utf8_lossy(&fetch_output.stderr)
618 );
619 return;
620 }
621 let checkout_output = Command::new("git")
622 .current_dir(&repo_dir)
623 .args(&["checkout", &sha])
624 .output()
625 .unwrap();
626
627 if !checkout_output.status.success() {
628 eprintln!(
629 "Failed to checkout {} for {}: {}",
630 sha,
631 repo,
632 String::from_utf8_lossy(&checkout_output.stderr)
633 );
634 }
635}