// eval.rs — CodeSearchNet semantic-search evaluation harness.

use ::fs::{Fs, RealFs};
use anyhow::{Context as _, Result};
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::FakeNodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use semantic_index::{OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Command, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};
 32
 33const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
 34const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
 35const EVAL_DB_PATH: &'static str = "target/eval_db";
 36const SEARCH_RESULT_LIMIT: usize = 8;
 37const SKIP_EVAL_PATH: &'static str = ".skip_eval";
 38
// Top-level command-line interface for the evaluation binary.
// NOTE: plain `//` comments are used here on purpose — `///` doc comments on
// clap derive items become user-visible help text.
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    // The subcommand to execute (see `Commands`).
    #[command(subcommand)]
    command: Commands,
}
 45
// Subcommands of the evaluation binary.
// NOTE: `//` comments on purpose — `///` would alter clap's generated help.
#[derive(clap::Subcommand)]
enum Commands {
    // Download the CodeSearchNet annotations and clone the referenced repos.
    Fetch {},
    // Run the evaluation, printing one JSON outcome per query to stdout.
    Run {
        // Restrict the run to a single repository (`owner/name` form).
        #[arg(long)]
        repo: Option<String>,
    },
}
 54
/// One repository's worth of evaluation data: a pinned commit plus the
/// annotated queries to run against it. Serialized to/from `evaluations.json`.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    /// Repository in `owner/name` form.
    repo: String,
    /// Commit SHA the annotations were made against.
    sha: String,
    /// Annotated queries for this repository.
    queries: Vec<EvaluationQuery>,
}
 61
/// A single natural-language search query and the human-annotated locations
/// that a good search should surface.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    /// The search query text.
    query: String,
    /// Locations annotated as relevant for this query.
    expected_results: Vec<EvaluationSearchResult>,
}
 67
/// A file location: a repo-relative path plus an inclusive line range.
/// Used both for expected (annotated) and actual (returned) results.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    /// Repo-relative file path.
    file: String,
    /// Inclusive range of line numbers within `file`.
    lines: RangeInclusive<u32>,
}
 73
/// Aggregated outcome for one evaluated repository.
// NOTE(review): this type is not constructed anywhere in this file —
// presumably consumed by downstream tooling that reads the JSON output;
// confirm before removing.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    /// Repository in `owner/name` form.
    repo: String,
    /// Commit SHA the evaluation ran against.
    sha: String,
    /// Per-query outcomes for this repository.
    queries: Vec<EvaluationQueryOutcome>,
}
 80
/// Outcome of running a single query against the semantic index; one of
/// these is printed as a JSON line to stdout per query.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    /// Repository in `owner/name` form.
    repo: String,
    /// The query text that was searched.
    query: String,
    /// Annotated locations the search was expected to surface.
    expected_results: Vec<EvaluationSearchResult>,
    /// Locations the search actually returned.
    actual_results: Vec<EvaluationSearchResult>,
    /// Expected results whose file matched some actual result.
    covered_file_count: usize,
    /// Expected results whose line range overlapped some actual result.
    overlapped_result_count: usize,
    /// Expected results fully contained within some actual result's range.
    covered_result_count: usize,
    /// Total number of expected results for this query.
    total_result_count: usize,
    /// Indices (into `actual_results`) that fully covered an expected result.
    covered_result_indices: Vec<usize>,
}
 93
 94fn main() -> Result<()> {
 95    let cli = Cli::parse();
 96    env_logger::init();
 97
 98    gpui::App::headless().run(move |cx| {
 99        let executor = cx.background_executor().clone();
100        let client = isahc_http_client::IsahcHttpClient::new(None, None);
101        cx.set_http_client(client.clone());
102        match cli.command {
103            Commands::Fetch {} => {
104                executor
105                    .clone()
106                    .spawn(async move {
107                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
108                            eprintln!("Error: {}", err);
109                            exit(1);
110                        }
111                        exit(0);
112                    })
113                    .detach();
114            }
115            Commands::Run { repo } => {
116                cx.spawn(|mut cx| async move {
117                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
118                        eprintln!("Error: {}", err);
119                        exit(1);
120                    }
121                    exit(0);
122                })
123                .detach();
124            }
125        }
126    });
127
128    Ok(())
129}
130
131async fn fetch_evaluation_resources(
132    http_client: Arc<dyn HttpClient>,
133    executor: &BackgroundExecutor,
134) -> Result<()> {
135    fetch_code_search_net_resources(&*http_client).await?;
136    fetch_eval_repos(executor, &*http_client).await?;
137    Ok(())
138}
139
140async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
141    eprintln!("Fetching CodeSearchNet evaluations...");
142
143    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
144
145    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
146    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
147
148    // Fetch the annotations CSV, which contains the human-annotated search relevances
149    let annotations_path = dataset_dir.join("annotations.csv");
150    let annotations_csv_content = if annotations_path.exists() {
151        fs::read_to_string(&annotations_path).expect("failed to read annotations")
152    } else {
153        let response = http_client
154            .get(annotations_url, Default::default(), true)
155            .await
156            .expect("failed to fetch annotations csv");
157        let mut body = String::new();
158        response
159            .into_body()
160            .read_to_string(&mut body)
161            .await
162            .expect("failed to read annotations.csv response");
163        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
164        body
165    };
166
167    // Parse the annotations CSV. Skip over queries with zero relevance.
168    let rows = annotations_csv_content.lines().filter_map(|line| {
169        let mut values = line.split(',');
170        let _language = values.next()?;
171        let query = values.next()?;
172        let github_url = values.next()?;
173        let score = values.next()?;
174
175        if score == "0" {
176            return None;
177        }
178
179        let url_path = github_url.strip_prefix("https://github.com/")?;
180        let (url_path, hash) = url_path.split_once('#')?;
181        let (repo_name, url_path) = url_path.split_once("/blob/")?;
182        let (sha, file_path) = url_path.split_once('/')?;
183        let line_range = if let Some((start, end)) = hash.split_once('-') {
184            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
185        } else {
186            let row = hash.strip_prefix("L")?.parse().ok()?;
187            row..=row
188        };
189        Some((repo_name, sha, query, file_path, line_range))
190    });
191
192    // Group the annotations by repo and sha.
193    let mut evaluations_by_repo = BTreeMap::new();
194    for (repo_name, sha, query, file_path, lines) in rows {
195        let evaluation_project = evaluations_by_repo
196            .entry((repo_name, sha))
197            .or_insert_with(|| EvaluationProject {
198                repo: repo_name.to_string(),
199                sha: sha.to_string(),
200                queries: Vec::new(),
201            });
202
203        let ix = evaluation_project
204            .queries
205            .iter()
206            .position(|entry| entry.query == query)
207            .unwrap_or_else(|| {
208                evaluation_project.queries.push(EvaluationQuery {
209                    query: query.to_string(),
210                    expected_results: Vec::new(),
211                });
212                evaluation_project.queries.len() - 1
213            });
214        let results = &mut evaluation_project.queries[ix].expected_results;
215        let result = EvaluationSearchResult {
216            file: file_path.to_string(),
217            lines,
218        };
219        if !results.contains(&result) {
220            results.push(result);
221        }
222    }
223
224    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
225    let evaluations_path = dataset_dir.join("evaluations.json");
226    fs::write(
227        &evaluations_path,
228        serde_json::to_vec_pretty(&evaluations).unwrap(),
229    )
230    .unwrap();
231
232    eprintln!(
233        "Fetched CodeSearchNet evaluations into {}",
234        evaluations_path.display()
235    );
236
237    Ok(())
238}
239
/// Run the evaluation: for each project in `evaluations.json`, build a
/// semantic index of its checked-out repo, run every annotated query, and
/// print one JSON `EvaluationQueryOutcome` per query to stdout. Aggregate
/// coverage counters are reported to stderr as the run progresses.
///
/// `only_repo`, when set, restricts the run to that single `owner/name`
/// repository. Panics if the `OPENAI_API_KEY` environment variable is
/// missing or if app-state setup fails.
async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncAppContext,
) -> Result<()> {
    // Bootstrap the minimal global state (settings, client, language,
    // project) that SemanticDb and Project require, and capture the app's
    // HTTP client out of the update closure.
    let mut http_client = None;
    cx.update(|cx| {
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        http_client = Some(cx.http_client());
        // Disable all feature flags so runs are reproducible.
        cx.update_flags(false, vec![]);
    })
    .unwrap();
    let http_client = http_client.unwrap();
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    // Required for the OpenAI embedding provider below.
    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    http_client.clone(),
                    "https://zed.dev",
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx
        .new_model(|cx| UserStore::new(client.clone(), cx))
        .unwrap();
    let node_runtime = Arc::new(FakeNodeRuntime {});

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

    // Aggregate counters across all projects/queries, reported on stderr.
    let mut covered_result_count = 0;
    let mut overlapped_result_count = 0;
    let mut covered_file_count = 0;
    let mut total_result_count = 0;
    eprint!("Running evals.");

    for evaluation_project in evaluations {
        // Honor the `--repo` filter, if provided.
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        // Rewrite the progress line in place (clear-to-end-of-line escape).
        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            covered_result_count,
            total_result_count,
            overlapped_result_count,
            total_result_count,
            covered_file_count,
            total_result_count,
            evaluation_project.repo
        );

        // One database per repo, keyed by a sanitized repo name.
        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
        let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider.clone(), cx)
            .await
            .unwrap();

        let repo_dir = repos_dir.join(&evaluation_project.repo);
        // NOTE(review): this message also prints when `.skip_eval` exists,
        // not only when the directory is missing.
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
            continue;
        }

        // Create a local project over the checked-out repo.
        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let (worktree, _) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(repo_dir, true, cx)
            })?
            .await?;

        // Wait for the worktree scan before indexing so all files are seen.
        worktree
            .update(cx, |worktree, _| {
                worktree.as_local().unwrap().scan_complete()
            })
            .unwrap()
            .await;

        let project_index = cx
            .update(|cx| semantic_index.create_project_index(project.clone(), cx))
            .unwrap();
        // Give indexing up to two minutes per project before moving on.
        wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;

        for query in evaluation_project.queries {
            let results = cx
                .update(|cx| {
                    let project_index = project_index.read(cx);
                    project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx)
                })
                .unwrap()
                .await
                .unwrap();

            // NOTE(review): `fs.clone()` here is redundant — a plain `&fs`
            // would do.
            let results = SemanticDb::load_results(results, &fs.clone(), &cx)
                .await
                .unwrap();

            // Score each expected result against the actual results:
            // - "file matched": some result is in the same file;
            // - "overlapped": a result's row range contains the expected
            //   start OR end line;
            // - "covered": a result's row range contains BOTH.
            let mut project_covered_result_count = 0;
            let mut project_overlapped_result_count = 0;
            let mut project_covered_file_count = 0;
            let mut covered_result_indices = Vec::new();
            for expected_result in &query.expected_results {
                let mut file_matched = false;
                let mut range_overlapped = false;
                let mut range_covered = false;

                for (ix, result) in results.iter().enumerate() {
                    if result.path.as_ref() == Path::new(&expected_result.file) {
                        file_matched = true;
                        let start_matched =
                            result.row_range.contains(&expected_result.lines.start());
                        let end_matched = result.row_range.contains(&expected_result.lines.end());

                        if start_matched || end_matched {
                            range_overlapped = true;
                        }

                        if start_matched && end_matched {
                            range_covered = true;
                            covered_result_indices.push(ix);
                            break;
                        }
                    }
                }

                if range_covered {
                    project_covered_result_count += 1
                };
                if range_overlapped {
                    project_overlapped_result_count += 1
                };
                if file_matched {
                    project_covered_file_count += 1
                };
            }
            let outcome_repo = evaluation_project.repo.clone();

            let query_results = EvaluationQueryOutcome {
                repo: outcome_repo,
                query: query.query,
                total_result_count: query.expected_results.len(),
                covered_result_count: project_covered_result_count,
                overlapped_result_count: project_overlapped_result_count,
                covered_file_count: project_covered_file_count,
                expected_results: query.expected_results,
                actual_results: results
                    .iter()
                    .map(|result| EvaluationSearchResult {
                        file: result.path.to_string_lossy().to_string(),
                        lines: result.row_range.clone(),
                    })
                    .collect(),
                covered_result_indices,
            };

            overlapped_result_count += query_results.overlapped_result_count;
            covered_result_count += query_results.covered_result_count;
            covered_file_count += query_results.covered_file_count;
            total_result_count += query_results.total_result_count;

            // Machine-readable outcome goes to stdout, one JSON per line.
            println!("{}", serde_json::to_string(&query_results).unwrap());
        }

        // Drop the GPUI entities inside an update closure — presumably so
        // they're released on the app thread; confirm against gpui's
        // entity-drop requirements.
        user_store
            .update(cx, |_, _| {
                drop(semantic_index);
                drop(project);
                drop(worktree);
                drop(project_index);
            })
            .unwrap();
    }

    // Final summary (no trailing project name).
    eprint!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured.",
        covered_result_count,
        total_result_count,
        overlapped_result_count,
        total_result_count,
        covered_file_count,
        total_result_count,
    );

    Ok(())
}
472
473async fn wait_for_indexing_complete(
474    project_index: &Model<ProjectIndex>,
475    cx: &mut AsyncAppContext,
476    timeout: Option<Duration>,
477) {
478    let (tx, rx) = bounded(1);
479    let subscription = cx.update(|cx| {
480        cx.subscribe(project_index, move |_, event, _| {
481            if let Status::Idle = event {
482                let _ = tx.try_send(*event);
483            }
484        })
485    });
486
487    let result = match timeout {
488        Some(timeout_duration) => {
489            smol::future::or(
490                async {
491                    rx.recv().await.map_err(|_| ())?;
492                    Ok(())
493                },
494                async {
495                    Timer::after(timeout_duration).await;
496                    Err(())
497                },
498            )
499            .await
500        }
501        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
502    };
503
504    match result {
505        Ok(_) => (),
506        Err(_) => {
507            if let Some(timeout) = timeout {
508                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
509            }
510        }
511    }
512
513    drop(subscription);
514}
515
516async fn fetch_eval_repos(
517    executor: &BackgroundExecutor,
518    http_client: &dyn HttpClient,
519) -> Result<()> {
520    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
521    let evaluations_path = dataset_dir.join("evaluations.json");
522    let repos_dir = Path::new(EVAL_REPOS_DIR);
523
524    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
525    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
526
527    eprint!("Fetching evaluation repositories...");
528
529    executor
530        .scoped(move |scope| {
531            let done_count = Arc::new(AtomicUsize::new(0));
532            let len = evaluations.len();
533            for chunk in evaluations.chunks(evaluations.len() / 8) {
534                let chunk = chunk.to_vec();
535                let done_count = done_count.clone();
536                scope.spawn(async move {
537                    for EvaluationProject { repo, sha, .. } in chunk {
538                        eprint!(
539                            "\rFetching evaluation repositories ({}/{})...",
540                            done_count.load(SeqCst),
541                            len,
542                        );
543
544                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
545                        done_count.fetch_add(1, SeqCst);
546                    }
547                });
548            }
549        })
550        .await;
551
552    Ok(())
553}
554
555async fn fetch_eval_repo(
556    repo: String,
557    sha: String,
558    repos_dir: &Path,
559    http_client: &dyn HttpClient,
560) {
561    let Some((owner, repo_name)) = repo.split_once('/') else {
562        return;
563    };
564    let repo_dir = repos_dir.join(owner).join(repo_name);
565    fs::create_dir_all(&repo_dir).unwrap();
566    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
567    if skip_eval_path.exists() {
568        return;
569    }
570    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
571        if head_content.trim() == sha {
572            return;
573        }
574    }
575    let repo_response = http_client
576        .send(
577            http_client::Request::builder()
578                .method(Method::HEAD)
579                .uri(format!("https://github.com/{}", repo))
580                .body(Default::default())
581                .expect(""),
582        )
583        .await
584        .expect("failed to check github repo");
585    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
586        fs::write(&skip_eval_path, "").unwrap();
587        eprintln!(
588            "Repo {repo} is no longer public ({:?}). Skipping",
589            repo_response.status()
590        );
591        return;
592    }
593    if !repo_dir.join(".git").exists() {
594        let init_output = Command::new("git")
595            .current_dir(&repo_dir)
596            .args(&["init"])
597            .output()
598            .unwrap();
599        if !init_output.status.success() {
600            eprintln!(
601                "Failed to initialize git repository for {}: {}",
602                repo,
603                String::from_utf8_lossy(&init_output.stderr)
604            );
605            return;
606        }
607    }
608    let url = format!("https://github.com/{}.git", repo);
609    Command::new("git")
610        .current_dir(&repo_dir)
611        .args(&["remote", "add", "-f", "origin", &url])
612        .stdin(Stdio::null())
613        .output()
614        .unwrap();
615    let fetch_output = Command::new("git")
616        .current_dir(&repo_dir)
617        .args(&["fetch", "--depth", "1", "origin", &sha])
618        .stdin(Stdio::null())
619        .output()
620        .unwrap();
621    if !fetch_output.status.success() {
622        eprintln!(
623            "Failed to fetch {} for {}: {}",
624            sha,
625            repo,
626            String::from_utf8_lossy(&fetch_output.stderr)
627        );
628        return;
629    }
630    let checkout_output = Command::new("git")
631        .current_dir(&repo_dir)
632        .args(&["checkout", &sha])
633        .output()
634        .unwrap();
635
636    if !checkout_output.status.success() {
637        eprintln!(
638            "Failed to checkout {} for {}: {}",
639            sha,
640            repo,
641            String::from_utf8_lossy(&checkout_output.stderr)
642        );
643    }
644}