// eval.rs

  1use ::fs::{Fs, RealFs};
  2use anyhow::Result;
  3use clap::Parser;
  4use client::{Client, UserStore};
  5use clock::RealSystemClock;
  6use collections::BTreeMap;
  7use dap::DapRegistry;
  8use feature_flags::FeatureFlagAppExt as _;
  9use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Entity};
 10use http_client::{HttpClient, Method};
 11use language::LanguageRegistry;
 12use node_runtime::NodeRuntime;
 13use open_ai::OpenAiEmbeddingModel;
 14use project::Project;
 15use reqwest_client::ReqwestClient;
 16use semantic_index::{
 17    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
 18};
 19use serde::{Deserialize, Serialize};
 20use settings::SettingsStore;
 21use smol::Timer;
 22use smol::channel::bounded;
 23use smol::io::AsyncReadExt;
 24use std::ops::RangeInclusive;
 25use std::path::PathBuf;
 26use std::time::Duration;
 27use std::{
 28    fs,
 29    path::Path,
 30    process::{Stdio, exit},
 31    sync::{
 32        Arc,
 33        atomic::{AtomicUsize, Ordering::SeqCst},
 34    },
 35};
 36
 37const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
 38const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
 39const EVAL_DB_PATH: &'static str = "target/eval_db";
 40const SEARCH_RESULT_LIMIT: usize = 8;
 41const SKIP_EVAL_PATH: &'static str = ".skip_eval";
 42
// Command-line interface for the evaluation binary.
// (Plain `//` comments are used deliberately: doc comments on clap derive
// items become `--help` text, which would change runtime output.)
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    // The subcommand to execute (`fetch` or `run`).
    #[command(subcommand)]
    command: Commands,
}
 49
// Subcommands of the eval binary. `//` comments are deliberate here: doc
// comments would be picked up by clap as help text and change `--help`.
#[derive(clap::Subcommand)]
enum Commands {
    // Download the CodeSearchNet annotations and clone the evaluation repos.
    Fetch {},
    // Run the evaluations against previously fetched repositories.
    Run {
        // Restrict the run to a single "owner/name" repository.
        #[arg(long)]
        repo: Option<String>,
    },
}
 58
/// A repository at a specific commit, plus the annotated search queries to
/// evaluate against it. Serialized to / deserialized from `evaluations.json`.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    /// Repository name in "owner/name" form.
    repo: String,
    /// Commit SHA the annotations were made against.
    sha: String,
    /// All annotated queries for this repo/sha pair.
    queries: Vec<EvaluationQuery>,
}
 65
/// A single search query together with the locations human annotators judged
/// relevant for it.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    /// The natural-language search query.
    query: String,
    /// Locations the search is expected to surface.
    expected_results: Vec<EvaluationSearchResult>,
}
 71
/// A location in a repository: a file path plus an inclusive line range.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    /// File path, relative to the repository root.
    file: String,
    /// Inclusive line range. NOTE(review): parsed from GitHub `#L…` anchors,
    /// which are 1-based — confirm the search results use the same base.
    lines: RangeInclusive<u32>,
}
 77
/// Per-project aggregation of query outcomes.
///
/// NOTE(review): this type is not constructed anywhere in this file —
/// presumably kept for consumers of the serialized output; verify before
/// removing.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    /// Repository name in "owner/name" form.
    repo: String,
    /// Commit SHA that was evaluated.
    sha: String,
    /// Outcome of each query evaluated against this project.
    queries: Vec<EvaluationQueryOutcome>,
}
 84
/// The scored outcome of running one query against an indexed repository.
/// Printed to stdout as one JSON line per query by `run_eval_project`.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    /// Repository the query ran against ("owner/name").
    repo: String,
    /// The natural-language search query.
    query: String,
    /// Annotated locations the search was expected to find.
    expected_results: Vec<EvaluationSearchResult>,
    /// Locations the semantic search actually returned.
    actual_results: Vec<EvaluationSearchResult>,
    /// Expected results whose file appeared among the actual results.
    covered_file_count: usize,
    /// Expected results whose line range partially overlapped an actual result.
    overlapped_result_count: usize,
    /// Expected results whose line range was fully contained in an actual result.
    covered_result_count: usize,
    /// Total number of expected results for this query.
    total_result_count: usize,
    /// Indices into `actual_results` that fully covered an expected result.
    covered_result_indices: Vec<usize>,
}
 97
/// Entry point: parses the CLI, then dispatches the chosen subcommand inside
/// a headless GPUI application.
///
/// Both subcommand tasks terminate the process themselves via `exit`, so the
/// trailing `Ok(())` is only reached if the app shuts down some other way.
fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::Application::headless().run(move |cx| {
        let executor = cx.background_executor().clone();
        let client = Arc::new(ReqwestClient::user_agent("Zed LLM evals").unwrap());
        cx.set_http_client(client.clone());
        match cli.command {
            Commands::Fetch {} => {
                // Fetching needs no app state, so it runs purely on the
                // background executor; the task is detached and exits the
                // process itself when done.
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                // Running evals needs the async app context (`cx`) for
                // entity creation and updates.
                cx.spawn(async move |cx| {
                    if let Err(err) = run_evaluation(repo, &executor, cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}
134
135async fn fetch_evaluation_resources(
136    http_client: Arc<dyn HttpClient>,
137    executor: &BackgroundExecutor,
138) -> Result<()> {
139    fetch_code_search_net_resources(&*http_client).await?;
140    fetch_eval_repos(executor, &*http_client).await?;
141    Ok(())
142}
143
144async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
145    eprintln!("Fetching CodeSearchNet evaluations...");
146
147    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
148
149    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
150    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
151
152    // Fetch the annotations CSV, which contains the human-annotated search relevances
153    let annotations_path = dataset_dir.join("annotations.csv");
154    let annotations_csv_content = if annotations_path.exists() {
155        fs::read_to_string(&annotations_path).expect("failed to read annotations")
156    } else {
157        let response = http_client
158            .get(annotations_url, Default::default(), true)
159            .await
160            .expect("failed to fetch annotations csv");
161        let mut body = String::new();
162        response
163            .into_body()
164            .read_to_string(&mut body)
165            .await
166            .expect("failed to read annotations.csv response");
167        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
168        body
169    };
170
171    // Parse the annotations CSV. Skip over queries with zero relevance.
172    let rows = annotations_csv_content.lines().filter_map(|line| {
173        let mut values = line.split(',');
174        let _language = values.next()?;
175        let query = values.next()?;
176        let github_url = values.next()?;
177        let score = values.next()?;
178
179        if score == "0" {
180            return None;
181        }
182
183        let url_path = github_url.strip_prefix("https://github.com/")?;
184        let (url_path, hash) = url_path.split_once('#')?;
185        let (repo_name, url_path) = url_path.split_once("/blob/")?;
186        let (sha, file_path) = url_path.split_once('/')?;
187        let line_range = if let Some((start, end)) = hash.split_once('-') {
188            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
189        } else {
190            let row = hash.strip_prefix("L")?.parse().ok()?;
191            row..=row
192        };
193        Some((repo_name, sha, query, file_path, line_range))
194    });
195
196    // Group the annotations by repo and sha.
197    let mut evaluations_by_repo = BTreeMap::new();
198    for (repo_name, sha, query, file_path, lines) in rows {
199        let evaluation_project = evaluations_by_repo
200            .entry((repo_name, sha))
201            .or_insert_with(|| EvaluationProject {
202                repo: repo_name.to_string(),
203                sha: sha.to_string(),
204                queries: Vec::new(),
205            });
206
207        let ix = evaluation_project
208            .queries
209            .iter()
210            .position(|entry| entry.query == query)
211            .unwrap_or_else(|| {
212                evaluation_project.queries.push(EvaluationQuery {
213                    query: query.to_string(),
214                    expected_results: Vec::new(),
215                });
216                evaluation_project.queries.len() - 1
217            });
218        let results = &mut evaluation_project.queries[ix].expected_results;
219        let result = EvaluationSearchResult {
220            file: file_path.to_string(),
221            lines,
222        };
223        if !results.contains(&result) {
224            results.push(result);
225        }
226    }
227
228    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
229    let evaluations_path = dataset_dir.join("evaluations.json");
230    fs::write(
231        &evaluations_path,
232        serde_json::to_vec_pretty(&evaluations).unwrap(),
233    )
234    .unwrap();
235
236    eprintln!(
237        "Fetched CodeSearchNet evaluations into {}",
238        evaluations_path.display()
239    );
240
241    Ok(())
242}
243
/// Aggregate tallies accumulated across all evaluated projects.
#[derive(Default, Debug)]
struct Counts {
    /// Expected results whose line range was fully contained in a search result.
    covered_results: usize,
    /// Expected results whose line range at least partially overlapped a result.
    overlapped_results: usize,
    /// Expected results whose file appeared among the search results.
    covered_files: usize,
    /// Total number of expected results examined.
    total_results: usize,
}
251
252async fn run_evaluation(
253    only_repo: Option<String>,
254    executor: &BackgroundExecutor,
255    cx: &mut AsyncApp,
256) -> Result<()> {
257    let mut http_client = None;
258    cx.update(|cx| {
259        let mut store = SettingsStore::new(cx);
260        store
261            .set_default_settings(settings::default_settings().as_ref(), cx)
262            .unwrap();
263        cx.set_global(store);
264        client::init_settings(cx);
265        language::init(cx);
266        Project::init_settings(cx);
267        http_client = Some(cx.http_client());
268        cx.update_flags(false, vec![]);
269    })
270    .unwrap();
271    let http_client = http_client.unwrap();
272    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
273    let evaluations_path = dataset_dir.join("evaluations.json");
274    let repos_dir = Path::new(EVAL_REPOS_DIR);
275    let db_path = Path::new(EVAL_DB_PATH);
276    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
277    let fs = Arc::new(RealFs::new(None, cx.background_executor().clone())) as Arc<dyn Fs>;
278    let clock = Arc::new(RealSystemClock);
279    let client = cx
280        .update(|cx| {
281            Client::new(
282                clock,
283                Arc::new(http_client::HttpClientWithUrl::new(
284                    http_client.clone(),
285                    "https://zed.dev",
286                    None,
287                )),
288                cx,
289            )
290        })
291        .unwrap();
292    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)).unwrap();
293    let node_runtime = NodeRuntime::unavailable();
294
295    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
296    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
297
298    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
299        http_client.clone(),
300        OpenAiEmbeddingModel::TextEmbedding3Small,
301        open_ai::OPEN_AI_API_URL.to_string(),
302        api_key,
303    ));
304
305    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
306    let debug_adapters = Arc::new(DapRegistry::default());
307    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
308        .unwrap();
309
310    let mut counts = Counts::default();
311    eprint!("Running evals.");
312
313    let mut failures = Vec::new();
314
315    for evaluation_project in evaluations {
316        if only_repo
317            .as_ref()
318            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
319        {
320            continue;
321        }
322
323        eprint!("\r\x1B[2K");
324        eprint!(
325            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
326            counts.covered_results,
327            counts.total_results,
328            counts.overlapped_results,
329            counts.total_results,
330            counts.covered_files,
331            counts.total_results,
332            evaluation_project.repo
333        );
334
335        let repo_dir = repos_dir.join(&evaluation_project.repo);
336        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
337            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
338            continue;
339        }
340
341        let repo_db_path =
342            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
343
344        let project = cx
345            .update(|cx| {
346                Project::local(
347                    client.clone(),
348                    node_runtime.clone(),
349                    user_store.clone(),
350                    language_registry.clone(),
351                    debug_adapters.clone(),
352                    fs.clone(),
353                    None,
354                    cx,
355                )
356            })
357            .unwrap();
358
359        let repo = evaluation_project.repo.clone();
360        if let Err(err) = run_eval_project(
361            evaluation_project,
362            &user_store,
363            repo_db_path,
364            &repo_dir,
365            &mut counts,
366            project,
367            embedding_provider.clone(),
368            fs.clone(),
369            cx,
370        )
371        .await
372        {
373            eprintln!("{repo} eval failed with error: {:?}", err);
374
375            failures.push((repo, err));
376        }
377    }
378
379    eprintln!(
380        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
381        counts.covered_results,
382        counts.total_results,
383        counts.overlapped_results,
384        counts.total_results,
385        counts.covered_files,
386        counts.total_results,
387        failures.len(),
388    );
389
390    if failures.is_empty() {
391        Ok(())
392    } else {
393        eprintln!("Failures:\n");
394
395        for (index, (repo, failure)) in failures.iter().enumerate() {
396            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
397        }
398
399        Err(anyhow::anyhow!("Some evals failed."))
400    }
401}
402
403async fn run_eval_project(
404    evaluation_project: EvaluationProject,
405    user_store: &Entity<UserStore>,
406    repo_db_path: PathBuf,
407    repo_dir: &Path,
408    counts: &mut Counts,
409    project: Entity<Project>,
410    embedding_provider: Arc<dyn EmbeddingProvider>,
411    fs: Arc<dyn Fs>,
412    cx: &mut AsyncApp,
413) -> Result<(), anyhow::Error> {
414    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;
415
416    let (worktree, _) = project
417        .update(cx, |project, cx| {
418            project.find_or_create_worktree(repo_dir, true, cx)
419        })?
420        .await?;
421
422    worktree
423        .update(cx, |worktree, _| {
424            worktree.as_local().unwrap().scan_complete()
425        })?
426        .await;
427
428    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
429    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
430
431    for query in evaluation_project.queries {
432        let results = {
433            // Retry search up to 3 times in case of timeout, network failure, etc.
434            let mut retries_remaining = 3;
435            let mut result;
436
437            loop {
438                match cx.update(|cx| {
439                    let project_index = project_index.read(cx);
440                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
441                }) {
442                    Ok(task) => match task.await {
443                        Ok(answer) => {
444                            result = Ok(answer);
445                            break;
446                        }
447                        Err(err) => {
448                            result = Err(err);
449                        }
450                    },
451                    Err(err) => {
452                        result = Err(err);
453                    }
454                }
455
456                if retries_remaining > 0 {
457                    eprintln!(
458                        "Retrying search after it failed on query {:?} with {:?}",
459                        query, result
460                    );
461                    retries_remaining -= 1;
462                } else {
463                    eprintln!(
464                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
465                        query, result
466                    );
467                    break;
468                }
469            }
470
471            SemanticDb::load_results(result?, &fs.clone(), &cx).await?
472        };
473
474        let mut project_covered_result_count = 0;
475        let mut project_overlapped_result_count = 0;
476        let mut project_covered_file_count = 0;
477        let mut covered_result_indices = Vec::new();
478        for expected_result in &query.expected_results {
479            let mut file_matched = false;
480            let mut range_overlapped = false;
481            let mut range_covered = false;
482
483            for (ix, result) in results.iter().enumerate() {
484                if result.path.as_ref() == Path::new(&expected_result.file) {
485                    file_matched = true;
486                    let start_matched = result.row_range.contains(expected_result.lines.start());
487                    let end_matched = result.row_range.contains(expected_result.lines.end());
488
489                    if start_matched || end_matched {
490                        range_overlapped = true;
491                    }
492
493                    if start_matched && end_matched {
494                        range_covered = true;
495                        covered_result_indices.push(ix);
496                        break;
497                    }
498                }
499            }
500
501            if range_covered {
502                project_covered_result_count += 1
503            };
504            if range_overlapped {
505                project_overlapped_result_count += 1
506            };
507            if file_matched {
508                project_covered_file_count += 1
509            };
510        }
511        let outcome_repo = evaluation_project.repo.clone();
512
513        let query_results = EvaluationQueryOutcome {
514            repo: outcome_repo,
515            query: query.query,
516            total_result_count: query.expected_results.len(),
517            covered_result_count: project_covered_result_count,
518            overlapped_result_count: project_overlapped_result_count,
519            covered_file_count: project_covered_file_count,
520            expected_results: query.expected_results,
521            actual_results: results
522                .iter()
523                .map(|result| EvaluationSearchResult {
524                    file: result.path.to_string_lossy().to_string(),
525                    lines: result.row_range.clone(),
526                })
527                .collect(),
528            covered_result_indices,
529        };
530
531        counts.overlapped_results += query_results.overlapped_result_count;
532        counts.covered_results += query_results.covered_result_count;
533        counts.covered_files += query_results.covered_file_count;
534        counts.total_results += query_results.total_result_count;
535
536        println!("{}", serde_json::to_string(&query_results)?);
537    }
538
539    user_store.update(cx, |_, _| {
540        drop(semantic_index);
541        drop(project);
542        drop(worktree);
543        drop(project_index);
544    })
545}
546
547async fn wait_for_indexing_complete(
548    project_index: &Entity<ProjectIndex>,
549    cx: &mut AsyncApp,
550    timeout: Option<Duration>,
551) {
552    let (tx, rx) = bounded(1);
553    let subscription = cx.update(|cx| {
554        cx.subscribe(project_index, move |_, event, _| {
555            if let Status::Idle = event {
556                let _ = tx.try_send(*event);
557            }
558        })
559    });
560
561    let result = match timeout {
562        Some(timeout_duration) => {
563            smol::future::or(
564                async {
565                    rx.recv().await.map_err(|_| ())?;
566                    Ok(())
567                },
568                async {
569                    Timer::after(timeout_duration).await;
570                    Err(())
571                },
572            )
573            .await
574        }
575        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
576    };
577
578    match result {
579        Ok(_) => (),
580        Err(_) => {
581            if let Some(timeout) = timeout {
582                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
583            }
584        }
585    }
586
587    drop(subscription);
588}
589
590async fn fetch_eval_repos(
591    executor: &BackgroundExecutor,
592    http_client: &dyn HttpClient,
593) -> Result<()> {
594    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
595    let evaluations_path = dataset_dir.join("evaluations.json");
596    let repos_dir = Path::new(EVAL_REPOS_DIR);
597
598    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
599    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
600
601    eprintln!("Fetching evaluation repositories...");
602
603    executor
604        .scoped(move |scope| {
605            let done_count = Arc::new(AtomicUsize::new(0));
606            let len = evaluations.len();
607            for chunk in evaluations.chunks(evaluations.len() / 8) {
608                let chunk = chunk.to_vec();
609                let done_count = done_count.clone();
610                scope.spawn(async move {
611                    for EvaluationProject { repo, sha, .. } in chunk {
612                        eprint!(
613                            "\rFetching evaluation repositories ({}/{})...",
614                            done_count.load(SeqCst),
615                            len,
616                        );
617
618                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
619                        done_count.fetch_add(1, SeqCst);
620                    }
621                });
622            }
623        })
624        .await;
625
626    Ok(())
627}
628
629async fn fetch_eval_repo(
630    repo: String,
631    sha: String,
632    repos_dir: &Path,
633    http_client: &dyn HttpClient,
634) {
635    let Some((owner, repo_name)) = repo.split_once('/') else {
636        return;
637    };
638    let repo_dir = repos_dir.join(owner).join(repo_name);
639    fs::create_dir_all(&repo_dir).unwrap();
640    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
641    if skip_eval_path.exists() {
642        return;
643    }
644    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
645        if head_content.trim() == sha {
646            return;
647        }
648    }
649    let repo_response = http_client
650        .send(
651            http_client::Request::builder()
652                .method(Method::HEAD)
653                .uri(format!("https://github.com/{}", repo))
654                .body(Default::default())
655                .expect(""),
656        )
657        .await
658        .expect("failed to check github repo");
659    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
660        fs::write(&skip_eval_path, "").unwrap();
661        eprintln!(
662            "Repo {repo} is no longer public ({:?}). Skipping",
663            repo_response.status()
664        );
665        return;
666    }
667    if !repo_dir.join(".git").exists() {
668        let init_output = util::command::new_std_command("git")
669            .current_dir(&repo_dir)
670            .args(&["init"])
671            .output()
672            .unwrap();
673        if !init_output.status.success() {
674            eprintln!(
675                "Failed to initialize git repository for {}: {}",
676                repo,
677                String::from_utf8_lossy(&init_output.stderr)
678            );
679            return;
680        }
681    }
682    let url = format!("https://github.com/{}.git", repo);
683    util::command::new_std_command("git")
684        .current_dir(&repo_dir)
685        .args(&["remote", "add", "-f", "origin", &url])
686        .stdin(Stdio::null())
687        .output()
688        .unwrap();
689    let fetch_output = util::command::new_std_command("git")
690        .current_dir(&repo_dir)
691        .args(&["fetch", "--depth", "1", "origin", &sha])
692        .stdin(Stdio::null())
693        .output()
694        .unwrap();
695    if !fetch_output.status.success() {
696        eprintln!(
697            "Failed to fetch {} for {}: {}",
698            sha,
699            repo,
700            String::from_utf8_lossy(&fetch_output.stderr)
701        );
702        return;
703    }
704    let checkout_output = util::command::new_std_command("git")
705        .current_dir(&repo_dir)
706        .args(&["checkout", &sha])
707        .output()
708        .unwrap();
709
710    if !checkout_output.status.success() {
711        eprintln!(
712            "Failed to checkout {} for {}: {}",
713            sha,
714            repo,
715            String::from_utf8_lossy(&checkout_output.stderr)
716        );
717    }
718}